mct-nightly 2.2.0.20241203.546__py3-none-any.whl → 2.2.0.20241205.533__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/RECORD +29 -29
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/common/graph/base_graph.py +9 -5
  5. model_compression_toolkit/core/common/graph/base_node.py +2 -3
  6. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +32 -35
  7. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +9 -9
  8. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +5 -11
  9. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +12 -0
  10. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +11 -4
  11. model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +4 -6
  12. model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +6 -11
  13. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +6 -9
  14. model_compression_toolkit/core/keras/data_util.py +151 -18
  15. model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py +93 -86
  16. model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py +17 -0
  17. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -2
  18. model_compression_toolkit/core/keras/keras_implementation.py +23 -27
  19. model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -4
  20. model_compression_toolkit/gptq/common/gptq_training.py +58 -0
  21. model_compression_toolkit/gptq/keras/gptq_loss.py +35 -2
  22. model_compression_toolkit/gptq/keras/gptq_training.py +137 -67
  23. model_compression_toolkit/gptq/keras/graph_info.py +1 -4
  24. model_compression_toolkit/gptq/keras/quantization_facade.py +24 -11
  25. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +23 -11
  26. model_compression_toolkit/gptq/pytorch/gptq_training.py +4 -45
  27. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/LICENSE.md +0 -0
  28. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/WHEEL +0 -0
  29. {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/top_level.txt +0 -0
model_compression_toolkit/core/keras/data_util.py
@@ -18,6 +18,27 @@ import tensorflow as tf
 
 from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor
 
+import tensorflow as tf
+from typing import Callable, Generator, Sequence, Any
+
+
+def get_tensor_spec(item, ignore_batch_dim=False):
+    """
+    Get the TensorFlow TensorSpec for an item, optionally ignoring the first dimension.
+
+    Args:
+        item: The input item, which could be a tensor, tuple, or list.
+        ignore_batch_dim (bool): Whether to ignore the first dimension of the tensor shape.
+
+    Returns:
+        TensorSpec or a tuple of TensorSpecs.
+    """
+    if isinstance(item, (tuple, list)):
+        return tuple(get_tensor_spec(sub_item, ignore_batch_dim) for sub_item in item)
+
+    shape = item.shape[1:] if ignore_batch_dim else item.shape
+    return tf.TensorSpec(shape=shape, dtype=item.dtype)
+
 
 def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
     """
@@ -29,39 +50,151 @@ def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
     Returns:
         A factory for a flattened data generator.
     """
+
     def gen():
         for inputs_batch in data_gen_fn():
             for sample in zip(*inputs_batch):
-                yield to_tf_tensor(sample)
-    return gen
+                yield tuple([tf.convert_to_tensor(s) for s in sample])
 
+    return gen
 
-# TODO in tf dataset and dataloader are combined within tf.data.Dataset. For advanced use cases such as gptq sla we
-#  need to separate dataset from dataloader similarly to torch data_util.
 class TFDatasetFromGenerator:
-    def __init__(self, data_gen, batch_size):
-        inputs = next(data_gen())
-        if not isinstance(inputs, list):
-            raise TypeError(f'Representative data generator is expected to generate a list of tensors, '
-                            f'got {type(inputs)}')  # pragma: no cover
+    """
+    TensorFlow dataset from a data generator function, batched to a specified size.
+    """
 
+    def __init__(self, data_gen_fn: Callable[[], Generator]):
+        """
+        Args:
+            data_gen_fn: a factory function for data generator that yields lists of tensors.
+        """
+        inputs = next(data_gen_fn())
+        if not isinstance(inputs, list):
+            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
         self.orig_batch_size = inputs[0].shape[0]
-
-        output_signature = tuple([tf.TensorSpec(shape=t.shape[1:], dtype=t.dtype) for t in inputs])
-        dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen), output_signature=output_signature)
-        self.dataset = dataset.batch(batch_size)
         self._size = None
 
+        # TFDatasetFromGenerator flattens the dataset, thus we ignore the batch dimension
+        output_signature = get_tensor_spec(inputs, ignore_batch_dim=True)
+        self.dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)
+
+
     def __iter__(self):
         return iter(self.dataset)
 
     def __len__(self):
         """ Returns the number of batches. """
         if self._size is None:
-            self._num_batches = sum(1 for _ in self)
-        return self._num_batches
+            self._size = sum(1 for _ in self.dataset)
+        return self._size
+
+
+
+class FixedTFDataset:
+    """
+    Fixed dataset containing samples from a generator, stored in memory.
+    """
+
+    def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
+        """
+        Args:
+            data_gen_fn: data generator function.
+            n_samples: number of samples to store in the dataset. If None, uses all samples in one pass.
+        """
+        inputs = next(data_gen_fn())
+        if not isinstance(inputs, list):
+            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
+        self.orig_batch_size = inputs[0].shape[0]
+
+        samples = []
+        for batch in data_gen_fn():
+            samples.extend(zip(*[tf.convert_to_tensor(t) for t in batch]))
+            if n_samples is not None and len(samples) >= n_samples:
+                samples = samples[:n_samples]
+                break
+
+        if n_samples and len(samples) < n_samples:
+            raise ValueError(f'Not enough samples to create a dataset with {n_samples} samples')
+        self.samples = samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        return self.samples[index]
+
+
+class FixedSampleInfoDataset:
+    """
+    Dataset for samples with additional info, each element is a tuple of (sample, sample_info).
+    """
+
+    def __init__(self, samples: Sequence, sample_info: Sequence):
+        if not all(len(info) == len(samples) for info in sample_info):
+            raise ValueError('Sample and additional info lengths must match')
+        self.samples = samples
+        self.sample_info = sample_info
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        return self.samples[index], tuple([info[index] for info in self.sample_info])
+
+
+class IterableSampleWithConstInfoDataset:
+    """
+    Augments each sample in an iterable dataset with constant additional information.
+    """
+
+    def __init__(self, samples_dataset: tf.data.Dataset, *info: Any):
+        self.samples_dataset = samples_dataset
+        self.info = info
+
+    def __iter__(self):
+        for sample in self.samples_dataset:
+            yield (sample, *self.info)
+
+
+def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
+    """Create a DataLoader based on samples yielded by data_gen."""
+    ds = TFDatasetFromGenerator(data_gen_fn)
+    return create_tf_dataloader(dataset=ds, batch_size=batch_size)
+
+
+def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
+    """
+    Creates a tf.data.Dataset with specified loading options.
+
+    Args:
+        dataset: The dataset container (e.g., FixedDatasetFromGenerator or FixedSampleInfoDataset).
+        batch_size: Number of samples per batch.
+        shuffle: Whether to shuffle the dataset.
+        collate_fn: A function to apply to each batch (e.g., add extra outputs like regularization weights).
+
+    Returns:
+        tf.data.Dataset: Configured for batching, shuffling, and custom transformations.
+    """
+    def generator():
+        for item in dataset:
+            yield item
+
+    dummy_input_tensors = next(generator())
+
+    output_signature = get_tensor_spec(dummy_input_tensors)
+
+    tf_dataset = tf.data.Dataset.from_generator(
+        generator,
+        output_signature=output_signature
+    )
+
+    if shuffle:
+        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
+
+    tf_dataset = tf_dataset.batch(batch_size)
 
+    # Apply collate function if provided
+    if collate_fn:
+        tf_dataset = tf_dataset.map(lambda *args: collate_fn(args))
 
-def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size) -> TFDatasetFromGenerator:
-    """ Create DataLoader based on samples yielded by data_gen. """
-    return TFDatasetFromGenerator(data_gen_fn, batch_size)
+    return tf_dataset
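
Note: the data_util refactor above separates dataset construction (TFDatasetFromGenerator, FixedTFDataset) from batching (create_tf_dataloader), mirroring the torch data_util so that GPTQ sample-layer-attention can re-batch samples and attach per-sample info. A minimal usage sketch of the new entry point; the representative generator below is hypothetical and only illustrates the expected "list of batched arrays" contract:

    import numpy as np
    from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader

    def representative_data_gen():
        # Two batches of eight (224, 224, 3) images (made-up data).
        for _ in range(2):
            yield [np.random.rand(8, 224, 224, 3).astype(np.float32)]

    # Samples are flattened and re-batched: 16 samples -> 4 batches of 4.
    dataloader = data_gen_to_dataloader(representative_data_gen, batch_size=4)
    for batch in dataloader:
        pass  # each batch is a tuple with one tf.Tensor of shape (4, 224, 224, 3)
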
model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py
@@ -60,96 +60,103 @@ class ActivationHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
         Returns:
             List[np.ndarray]: Scores based on the Hessian-approximation for the requested nodes.
         """
-        if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
-            model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
-
-            if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
-                Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
-                                "This operation is not supported. "
-                                "Remove the output node from the set of node targets in the Hessian request.")
-
-            grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
-
-            # Building a model to run Hessian approximation on
-            model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
-
-            # Record operations for automatic differentiation
-            with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
-                g.watch(self.input_images)
-
-                if len(self.input_images) > 1:
-                    outputs = model(self.input_images)
-                else:
-                    outputs = model(*self.input_images)
-
-                if len(outputs) != len(grad_model_outputs):  # pragma: no cover
-                    Logger.critical(
-                        f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
-                        f"outputs, but got {len(outputs)} output tensors.")
-
-                # Extracting the intermediate activation tensors and the model real output.
-                # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
-                num_target_nodes = len(self.hessian_request.target_nodes)
-                # Extract activation tensors of nodes for which we want to compute Hessian
-                target_activation_tensors = outputs[:num_target_nodes]
-                # Extract the model outputs
-                output_tensors = outputs[num_target_nodes:]
-
-                # Unfold and concatenate all outputs to form a single tensor
-                output = self._concat_tensors(output_tensors)
-
-                # List to store the Hessian-approximation scores for each interest point
-                ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
-                                               for _ in range(len(target_activation_tensors))]
-
-                # Loop through each interest point activation tensor
-                prev_mean_results = None
-                for j in tqdm(range(self.num_iterations_for_approximation)):  # Approximation iterations
-                    # Getting a random vector with normal distribution
-                    v = tf.random.normal(shape=output.shape, dtype=output.dtype)
-                    f_v = tf.reduce_sum(v * output)
-                    for i, ipt in enumerate(target_activation_tensors):  # Per Interest point activation tensor
-                        interest_point_scores = []  # List to store scores for each interest point
-                        with g.stop_recording():
-                            # Computing the approximation by getting the gradient of (output * v)
-                            hess_v = g.gradient(f_v, ipt)
-
-                            if hess_v is None:
-                                # In case we have an output node, which is an interest point, but it is not
-                                # differentiable, we consider its Hessian to be the initial value 0.
-                                continue  # pragma: no cover
-
+        model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
+
+        if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
+            Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
+                            "This operation is not supported. "
+                            "Remove the output node from the set of node targets in the Hessian request.")
+
+        grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
+
+        # Building a model to run Hessian approximation on
+        model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
+
+        # Record operations for automatic differentiation
+        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
+            g.watch(self.input_images)
+
+            if len(self.input_images) > 1:
+                outputs = model(self.input_images)
+            else:
+                outputs = model(*self.input_images)
+
+            if len(outputs) != len(grad_model_outputs):  # pragma: no cover
+                Logger.critical(
+                    f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
+                    f"outputs, but got {len(outputs)} output tensors.")
+
+            # Extracting the intermediate activation tensors and the model real output.
+            # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
+            num_target_nodes = len(self.hessian_request.target_nodes)
+            # Extract activation tensors of nodes for which we want to compute Hessian
+            target_activation_tensors = outputs[:num_target_nodes]
+            # Extract the model outputs
+            output_tensors = outputs[num_target_nodes:]
+
+            # Unfold and concatenate all outputs to form a single tensor
+            output = self._concat_tensors(output_tensors)
+
+            # List to store the Hessian-approximation scores for each interest point
+            ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
+                                           for _ in range(len(target_activation_tensors))]
+
+            # Loop through each interest point activation tensor
+            prev_mean_results = None
+            for j in tqdm(range(self.num_iterations_for_approximation)):  # Approximation iterations
+                # Generate random tensor of 1s and -1s
+                v = self._generate_random_vectors_batch(output.shape)
+                f_v = tf.reduce_sum(v * output)
+                for i, ipt in enumerate(target_activation_tensors):  # Per Interest point activation tensor
+                    interest_point_scores = []  # List to store scores for each interest point
+                    with g.stop_recording():
+                        # Computing the approximation by getting the gradient of (output * v)
+                        hess_v = g.gradient(f_v, ipt)
+
+                        if hess_v is None:
+                            # In case we have an output node, which is an interest point, but it is not
+                            # differentiable, we consider its Hessian to be the initial value 0.
+                            continue  # pragma: no cover
+
+                        if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
                             # Mean over all dims but the batch (CXHXW for conv)
                             hessian_approx = tf.reduce_sum(hess_v ** 2.0,
                                                            axis=tuple(d for d in range(1, len(hess_v.shape))))
-
-                            # Free gradients
-                            del hess_v
-
-                            # Update node Hessian approximation mean over random iterations
-                            ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
-
-                    # If the change to the mean approximation is insignificant (to all outputs)
-                    # we stop the calculation.
-                    if j > MIN_HESSIAN_ITER:
-                        if prev_mean_results is not None:
-                            new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
-                            relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
-                                                       (tf.abs(new_mean_res) + 1e-6))
-                            max_delta = tf.reduce_max(relative_delta_per_node)
-                            if max_delta < HESSIAN_COMP_TOLERANCE:
-                                break
+                        elif self.hessian_request.granularity == HessianScoresGranularity.PER_ELEMENT:
+                            hessian_approx = hess_v ** 2
+                        elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
+                            axes_to_sum = tuple(d for d in range(1, len(hess_v.shape)-1))
+                            hessian_approx = tf.reduce_sum(hess_v ** 2.0, axis=axes_to_sum)
+
+                        else:  # pragma: no cover
+                            Logger.critical(f"{self.hessian_request.granularity} "
+                                            f"is not supported for Keras activation hessian\'s approximation scores calculator.")
+
+                        # Free gradients
+                        del hess_v
+
+                        # Update node Hessian approximation mean over random iterations
+                        ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
+
+                # If the change to the mean approximation is insignificant (to all outputs)
+                # we stop the calculation.
+                if j > MIN_HESSIAN_ITER and prev_mean_results is not None:
+                    new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
+                    relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
+                                               (tf.abs(new_mean_res) + 1e-6))
+                    max_delta = tf.reduce_max(relative_delta_per_node)
+                    if max_delta < HESSIAN_COMP_TOLERANCE:
+                        break
+
+                if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
                     prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
 
-                # Convert results to list of numpy arrays
-                hessian_results = [h.numpy() for h in ipts_hessian_approximations]
-                # Extend the Hessian tensors shape to align with expected return type
-                # TODO: currently, only per-tensor Hessian is available for activation.
-                #  Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
-                hessian_results = [h[..., np.newaxis] for h in hessian_results]
+            # Convert results to list of numpy arrays
+            hessian_results = [h.numpy() for h in ipts_hessian_approximations]
+            # Extend the Hessian tensors shape to align with expected return type
+            # TODO: currently, only per-tensor Hessian is available for activation.
+            #  Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
+            hessian_results = [h[..., np.newaxis] for h in hessian_results]
 
-                return hessian_results
+            return hessian_results
 
-        else:  # pragma: no cover
-            Logger.critical(f"{self.hessian_request.granularity} "
-                            f"is not supported for Keras activation hessian\'s approximation scores calculator.")
model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from tensorflow import TensorShape
 
 from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
 
@@ -77,3 +78,19 @@ class HessianScoresCalculatorKeras(HessianScoresCalculator):
                 "Unable to concatenate tensors for gradient calculation due to mismatched shapes along the first axis.")  # pragma: no cover
 
         return tf.concat(_r_tensors, axis=1)
+
+    def _generate_random_vectors_batch(self, shape: TensorShape) -> tf.Tensor:
+        """
+        Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
+
+        Args:
+            shape: target shape.
+
+        Returns:
+            Random tensor.
+        """
+        v = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)
+        v = tf.where(v == 0, -1, 1)
+        v = tf.cast(v, tf.float32)
+        return v
+
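
Note: both Hessian-score calculators now draw Rademacher (+/-1) probe vectors instead of Gaussian ones for the Hutchinson estimator; any zero-mean probe with identity covariance keeps the estimate unbiased, and +/-1 entries generally give lower variance than Gaussian probes. A standalone sanity check of the generation logic (illustration only, not the class method itself):

    import tensorflow as tf

    def rademacher(shape):
        v = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)  # 0 or 1
        return tf.cast(tf.where(v == 0, -1, 1), tf.float32)                     # map to -1 / +1

    v = rademacher((1000, 8))
    values = tf.unique(tf.reshape(v, [-1]))[0].numpy()
    assert set(values).issubset({-1.0, 1.0})  # entries are exactly +/-1, mean ~0, variance 1
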
model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py
@@ -89,8 +89,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
         prev_mean_results = None
         tensors_original_shape = []
         for j in tqdm(range(self.num_iterations_for_approximation)):  # Approximation iterations
-            # Getting a random vector with normal distribution and the same shape as the model output
-            v = tf.random.normal(shape=output.shape)
+            v = self._generate_random_vectors_batch(output.shape)
             f_v = tf.reduce_sum(v * output)
 
             for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
model_compression_toolkit/core/keras/keras_implementation.py
@@ -438,17 +438,11 @@ class KerasImplementation(FrameworkImplementation):
             node: Node to indicate whether it needs to be part of the interest points set.
         Returns: True if the node should be considered an interest point, False otherwise.
         """
-
-        if node.is_match_type(Activation):
-            node_type_name = node.framework_attr[keras_constants.ACTIVATION]
-            if node_type_name in [keras_constants.SOFTMAX, keras_constants.SIGMOID]:
-                return True
-        elif any([node.is_match_type(_type) for _type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D,
-                                                          DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, tf.concat,
-                                                          Add, tf.add]]):
+        if self.is_softmax(node) or self.is_sigmoid(node):
             return True
 
-        return False
+        return any([node.is_match_type(_type) for _type in [Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense,
+                                                            Concatenate, tf.concat, Add, tf.add]])
 
     def get_mp_node_distance_fn(self, n: BaseNode,
                                 compute_distance_fn: Callable = None,
@@ -466,32 +460,34 @@ class KerasImplementation(FrameworkImplementation):
         Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
         """
 
-        axis = n.framework_attr.get(keras_constants.AXIS) \
-            if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(keras_constants.AXIS)
-
-        layer_class = n.layer_class
-        framework_attrs = n.framework_attr
+        axis = n.op_call_kwargs.get(keras_constants.AXIS) if isinstance(n, FunctionalNode) else n.framework_attr.get(keras_constants.AXIS)
 
         if compute_distance_fn is not None:
             return compute_distance_fn, axis
 
-        if layer_class == Activation:
-            node_type_name = framework_attrs[ACTIVATION]
-            if node_type_name == SOFTMAX and axis is not None:
-                return compute_kl_divergence, axis
-            elif node_type_name == SIGMOID:
-                return compute_cs, axis
-        elif axis is not None and (layer_class == tf.nn.softmax or layer_class == tf.keras.layers.Softmax
-                                   or (layer_class == TFOpLambda and
-                                       SOFTMAX in framework_attrs[keras_constants.FUNCTION])):
+        # TODO should we really return mse if axis is None? Error? Fill default?
+        if self.is_softmax(n) and axis is not None:
             return compute_kl_divergence, axis
-        elif layer_class == tf.nn.sigmoid or (layer_class == TFOpLambda and
-                                              SIGMOID in framework_attrs[keras_constants.FUNCTION]):
-            return compute_cs, axis
-        elif layer_class == Dense:
+
+        if self.is_sigmoid(n) or n.layer_class == Dense:
             return compute_cs, axis
+
         return partial(compute_mse, norm=norm_mse), axis
 
+    @staticmethod
+    def is_sigmoid(node: BaseNode):
+        cls = node.layer_class
+        return ((cls == Activation and node.framework_attr[ACTIVATION] == SIGMOID) or
+                cls == tf.nn.sigmoid or
+                cls == TFOpLambda and SIGMOID in node.framework_attr[keras_constants.FUNCTION])
+
+    @staticmethod
+    def is_softmax(node: BaseNode):
+        cls = node.layer_class
+        return ((cls == Activation and node.framework_attr[ACTIVATION] == SOFTMAX) or
+                cls in [tf.nn.softmax, tf.keras.layers.Softmax] or
+                cls == TFOpLambda and SOFTMAX in node.framework_attr[keras_constants.FUNCTION])
+
     def get_hessian_scores_calculator(self,
                                       graph: Graph,
                                       input_images: List[Any],
model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -427,10 +427,8 @@ class PytorchImplementation(FrameworkImplementation):
         Returns: True if the node should be considered an interest point, False otherwise.
         """
 
-        if any([node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
-                                                        softmax, operator.add, add, cat, operator.concat]]):
-            return True
-        return False
+        return any(node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
+                                                           softmax, operator.add, add, cat, operator.concat])
 
     def get_mp_node_distance_fn(self, n: BaseNode,
                                 compute_distance_fn: Callable = None,
model_compression_toolkit/gptq/common/gptq_training.py
@@ -27,7 +27,11 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
 from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
 from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
+from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
+    get_gradual_activation_quantizer_wrapper_factory
+from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
 
 
 class GPTQTrainer(ABC):
@@ -64,6 +68,14 @@ class GPTQTrainer(ABC):
         self.fw_impl = fw_impl
         self.fw_info = fw_info
         self.representative_data_gen_fn = representative_data_gen_fn
+
+        def _get_total_grad_steps():
+            return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs
+
+        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(gptq_config,
+                                                                                                      _get_total_grad_steps,
+                                                                                                      self.fw_linear_annealing_scheduler)
+
         # ----------------------------------------------
         # Build two models and create compare nodes
         # ----------------------------------------------
@@ -81,6 +93,52 @@ class GPTQTrainer(ABC):
                             f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.")  # pragma: no cover
         self.hessian_service = hessian_info_service
 
+        self.reg_func = get_regularization(self.gptq_config,
+                                           _get_total_grad_steps,
+                                           self.fw_soft_quantizer_regularization,
+                                           self.fw_linear_annealing_scheduler)
+        self.loss_list = []
+        self.input_scale = 1
+        if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
+            Logger.critical("Input scale mismatch between float and GPTQ networks. "
+                            "Ensure both networks have matching input scales.")  # pragma: no cover
+        else:
+            self.input_scale = self.gptq_user_info.input_scale
+
+        trainable_weights, trainable_bias, trainable_threshold = self.fw_get_gptq_trainable_parameters_fn(
+            self.fxp_model,
+            add_bias=self.gptq_config.train_bias)
+        self.flp_weights_list, self.fxp_weights_list = self.fw_get_weights_for_loss_fn(self.fxp_model)
+
+        if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
+                self.fxp_weights_list)):
+            Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
+                            "and the number of float and quantized weights for loss calculation. "
+                            "Ensure all these elements align to proceed with GPTQ training.")
+
+        # In Keras we need to flatten the weights first before attaching the optimizer
+        if len(trainable_weights) > 0 and isinstance(trainable_weights[0], (list, tuple)):
+            trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
+        if len(trainable_bias) > 0 and isinstance(trainable_bias[0], (list, tuple)):
+            trainable_bias = [w for layer_weights in trainable_bias for w in layer_weights]
+
+        self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
+                                                                  trainable_bias,
+                                                                  trainable_threshold)
+        hessian_cfg = self.gptq_config.hessian_weights_config
+
+        self.has_params_to_train = np.sum(
+            [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
+        self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
+
+        if self.use_sample_layer_attention:
+            # normalization is currently not supported, make sure the config reflects it.
+            if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
+                raise NotImplementedError()
+            self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen_fn)
+        else:
+            self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen_fn)
+
     def get_optimizer_with_param(self,
                                  flattened_trainable_weights: List[Any],
                                  flattened_bias_weights: List[Any],
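
Note: the trainer now computes the total number of gradient steps up front and passes it to both the gradual activation quantizer factory and the regularization scheduler. The arithmetic is simply steps-per-epoch times epochs; a rough sketch, assuming get_total_grad_steps counts the batches yielded by one pass over the representative data generator (the generator and numbers below are made up):

    def representative_data_gen():
        for _ in range(10):        # 10 batches per epoch (hypothetical)
            yield [None]           # placeholder for a list of batched inputs

    n_epochs = 5
    steps_per_epoch = sum(1 for _ in representative_data_gen())   # -> 10
    total_grad_steps = steps_per_epoch * n_epochs                 # -> 50, drives the linear annealing schedule
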
model_compression_toolkit/gptq/keras/gptq_loss.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Any, Tuple, List
-
 import tensorflow as tf
+from typing import List, Tuple
 
 
 def mse_loss(y: tf.Tensor, x: tf.Tensor, normalized: bool = True) -> tf.Tensor:
@@ -67,6 +66,40 @@ def multiple_tensors_mse_loss(y_list: List[tf.Tensor],
     else:
         return tf.reduce_mean(tf.stack(loss_values_list))
 
+def sample_layer_attention_loss(y_list: List[tf.Tensor],
+                                x_list: List[tf.Tensor],
+                                fxp_w_list,
+                                flp_w_list,
+                                act_bn_mean,
+                                act_bn_std,
+                                loss_weights: Tuple[tf.Tensor]) -> tf.Tensor:
+    """
+    Compute Sample Layer Attention loss between two lists of tensors using TensorFlow.
+
+    Args:
+        y_list: First list of tensors.
+        x_list: Second list of tensors.
+        fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
+        loss_weights: layer-sample attention scores (tuplle by the same length as the number of layers, where each element is a tf.Tensor vector of length of number of samples).
+
+    Returns:
+        Sample Layer Attention loss (a scalar).
+    """
+    loss = 0
+    layers_mean_w = []
+    loss_weights = tf.stack(loss_weights, axis=1)
+
+    for i, (y, x) in enumerate(zip(y_list, x_list)):
+        norm = tf.reduce_sum(tf.square(y - x), axis=1)
+        if len(norm.shape) > 1:
+            norm = tf.reduce_mean(tf.reshape(norm, [norm.shape[0], -1]), axis=1)
+        w = loss_weights[:, i]
+        loss += tf.reduce_mean(w * norm)
+        layers_mean_w.append(tf.reduce_mean(w))
+
+    loss = loss / tf.reduce_max(tf.stack(layers_mean_w))
+    return loss
+
 
 def mse_loss_per_tensor(y: tf.Tensor,
                         x: tf.Tensor,