mct-nightly 2.2.0.20241203.546__py3-none-any.whl → 2.2.0.20241205.533__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/RECORD +29 -29
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/graph/base_graph.py +9 -5
- model_compression_toolkit/core/common/graph/base_node.py +2 -3
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +32 -35
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +9 -9
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +5 -11
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py +12 -0
- model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/ru_functions_mapping.py +11 -4
- model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py +4 -6
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py +6 -11
- model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +6 -9
- model_compression_toolkit/core/keras/data_util.py +151 -18
- model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py +93 -86
- model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py +17 -0
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -2
- model_compression_toolkit/core/keras/keras_implementation.py +23 -27
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +2 -4
- model_compression_toolkit/gptq/common/gptq_training.py +58 -0
- model_compression_toolkit/gptq/keras/gptq_loss.py +35 -2
- model_compression_toolkit/gptq/keras/gptq_training.py +137 -67
- model_compression_toolkit/gptq/keras/graph_info.py +1 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +24 -11
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +23 -11
- model_compression_toolkit/gptq/pytorch/gptq_training.py +4 -45
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241203.546.dist-info → mct_nightly-2.2.0.20241205.533.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,27 @@ import tensorflow as tf
|
|
18
18
|
|
19
19
|
from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor
|
20
20
|
|
21
|
+
import tensorflow as tf
|
22
|
+
from typing import Callable, Generator, Sequence, Any
|
23
|
+
|
24
|
+
|
25
|
+
def get_tensor_spec(item, ignore_batch_dim=False):
    """
    Build a TensorSpec describing ``item``, recursing into nested tuples/lists.

    Args:
        item: A tensor-like object exposing ``shape`` and ``dtype``, or a (possibly nested) tuple/list of such objects.
        ignore_batch_dim (bool): If True, the leading dimension of every leaf shape is dropped.

    Returns:
        A tf.TensorSpec for a leaf, or a tuple of specs mirroring the nesting of ``item``
        (lists are converted to tuples).
    """
    # Leaf case: a single tensor-like object.
    if not isinstance(item, (tuple, list)):
        leaf_shape = item.shape
        if ignore_batch_dim:
            leaf_shape = leaf_shape[1:]
        return tf.TensorSpec(shape=leaf_shape, dtype=item.dtype)

    # Container case: describe every element, preserving order.
    specs = [get_tensor_spec(element, ignore_batch_dim) for element in item]
    return tuple(specs)
|
41
|
+
|
21
42
|
|
22
43
|
def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
|
23
44
|
"""
|
@@ -29,39 +50,151 @@ def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
|
|
29
50
|
Returns:
|
30
51
|
A factory for a flattened data generator.
|
31
52
|
"""
|
53
|
+
|
32
54
|
def gen():
|
33
55
|
for inputs_batch in data_gen_fn():
|
34
56
|
for sample in zip(*inputs_batch):
|
35
|
-
yield
|
36
|
-
return gen
|
57
|
+
yield tuple([tf.convert_to_tensor(s) for s in sample])
|
37
58
|
|
59
|
+
return gen
|
38
60
|
|
39
|
-
# TODO in tf dataset and dataloader are combined within tf.data.Dataset. For advanced use cases such as gptq sla we
|
40
|
-
# need to separate dataset from dataloader similarly to torch data_util.
|
41
61
|
class TFDatasetFromGenerator:
    """
    Per-sample (unbatched) TensorFlow dataset built from a batched data generator function.

    The batches yielded by the generator are flattened, so every element of ``self.dataset`` is a single sample:
    a tuple with one entry per tensor in the generated batch list. Re-batching is left to the caller
    (e.g. ``create_tf_dataloader``).
    """

    def __init__(self, data_gen_fn: Callable[[], Generator]):
        """
        Args:
            data_gen_fn: a factory function for data generator that yields lists of tensors.

        Raises:
            ValueError: if the generator yields no batches.
            TypeError: if the first yielded batch is not a list.
        """
        # Peek at the first batch to validate its type and infer the per-sample signature. A sentinel is used so
        # that an empty generator raises a clear error instead of a bare StopIteration.
        _missing = object()
        inputs = next(data_gen_fn(), _missing)
        if inputs is _missing:
            raise ValueError('Data generator is expected to yield at least one batch, but it yielded none.')
        if not isinstance(inputs, list):
            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
        self.orig_batch_size = inputs[0].shape[0]
        self._size = None

        # TFDatasetFromGenerator flattens the dataset, thus we ignore the batch dimension
        output_signature = get_tensor_spec(inputs, ignore_batch_dim=True)
        self.dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)

    def __iter__(self):
        return iter(self.dataset)

    def __len__(self):
        """
        Returns the number of samples (not batches), since ``self.dataset`` is unbatched.
        Counting requires a full pass over the data, so the result is computed once and cached.
        """
        if self._size is None:
            self._size = sum(1 for _ in self.dataset)
        return self._size
|
90
|
+
|
91
|
+
|
92
|
+
|
93
|
+
class FixedTFDataset:
    """
    In-memory dataset of individual samples collected from a batched data generator.

    Each stored sample is a tuple holding, for every tensor in a generated batch, that tensor's slice
    along its first axis.
    """

    def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
        """
        Args:
            data_gen_fn: factory returning a fresh generator of batches (each batch a list of tensors).
            n_samples: number of samples to keep. If None, every sample from a single pass is kept.

        Raises:
            TypeError: if the first generated batch is not a list.
            ValueError: if the generator runs out before ``n_samples`` samples were collected.
        """
        inputs = next(data_gen_fn())
        if not isinstance(inputs, list):
            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
        self.orig_batch_size = inputs[0].shape[0]

        collected = []
        for batch in data_gen_fn():
            batch_tensors = [tf.convert_to_tensor(t) for t in batch]
            # Transpose the batch: one tuple per sample, containing that sample's slice of each tensor.
            collected.extend(zip(*batch_tensors))
            if n_samples is None:
                continue
            if len(collected) >= n_samples:
                collected = collected[:n_samples]
                break

        if n_samples and len(collected) < n_samples:
            raise ValueError(f'Not enough samples to create a dataset with {n_samples} samples')
        self.samples = collected

    def __len__(self):
        """Number of stored samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Stored sample at position ``index``."""
        return self.samples[index]
|
125
|
+
|
126
|
+
|
127
|
+
class FixedSampleInfoDataset:
    """
    Indexable dataset pairing every sample with per-sample auxiliary values.

    Item ``i`` is ``(samples[i], (info_0[i], info_1[i], ...))`` where ``info_k`` are the sequences in ``sample_info``.
    """

    def __init__(self, samples: Sequence, sample_info: Sequence):
        """
        Args:
            samples: indexable collection of samples.
            sample_info: collection of sequences, each aligned with ``samples``.

        Raises:
            ValueError: if any sequence in ``sample_info`` differs in length from ``samples``.
        """
        if any(len(info) != len(samples) for info in sample_info):
            raise ValueError('Sample and additional info lengths must match')
        self.samples = samples
        self.sample_info = sample_info

    def __len__(self):
        """Number of samples."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(sample, extra_info_tuple)`` for position ``index``."""
        sample = self.samples[index]
        extra = tuple(info[index] for info in self.sample_info)
        return sample, extra
|
143
|
+
|
144
|
+
|
145
|
+
class IterableSampleWithConstInfoDataset:
    """
    Iterable wrapper that appends the same extra values to every sample of an underlying dataset.

    Iterating yields ``(sample, info_0, info_1, ...)`` for each ``sample`` in ``samples_dataset``.
    """

    def __init__(self, samples_dataset: tf.data.Dataset, *info: Any):
        """
        Args:
            samples_dataset: iterable of samples.
            *info: constant values appended to every yielded sample.
        """
        self.samples_dataset = samples_dataset
        self.info = info

    def __iter__(self):
        for sample in self.samples_dataset:
            yield (sample,) + tuple(self.info)
|
157
|
+
|
158
|
+
|
159
|
+
def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
    """
    Build a batched tf.data pipeline from a representative data generator factory.

    Args:
        data_gen_fn: factory returning a generator of batches (lists of tensors).
        batch_size: number of samples per batch in the returned dataset.

    Returns:
        The dataset produced by ``create_tf_dataloader`` over a ``TFDatasetFromGenerator``
        (no shuffling, no collate function).
    """
    return create_tf_dataloader(dataset=TFDatasetFromGenerator(data_gen_fn), batch_size=batch_size)
|
163
|
+
|
164
|
+
|
165
|
+
def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
    """
    Creates a tf.data.Dataset with specified loading options.

    Args:
        dataset: The dataset container (e.g., TFDatasetFromGenerator, FixedTFDataset or FixedSampleInfoDataset).
            It is iterated once here to infer the element signature and again by the returned pipeline, so it
            must be re-iterable. It must also support ``len()`` when ``shuffle`` is True.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle the dataset. The shuffle buffer size equals ``len(dataset)``.
        collate_fn: A function to apply to each batch (e.g., add extra outputs like regularization weights).
            It receives the batch components packed into a single tuple.

    Returns:
        tf.data.Dataset: Configured for batching, shuffling, and custom transformations.

    Raises:
        ValueError: If the dataset yields no items, so the element signature cannot be inferred.
    """
    def generator():
        for item in dataset:
            yield item

    # Peek at the first item to infer the element signature. A sentinel is used instead of a bare next() so that an
    # empty dataset produces a clear error rather than a StopIteration escaping from this (non-generator) function.
    _missing = object()
    dummy_input_tensors = next(generator(), _missing)
    if dummy_input_tensors is _missing:
        raise ValueError('Cannot create a dataloader from an empty dataset.')

    output_signature = get_tensor_spec(dummy_input_tensors)

    tf_dataset = tf.data.Dataset.from_generator(generator, output_signature=output_signature)

    if shuffle:
        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))

    tf_dataset = tf_dataset.batch(batch_size)

    # Apply collate function if provided. Dataset.map passes tuple components as separate positional arguments,
    # so they are repacked into a single tuple for collate_fn.
    if collate_fn is not None:
        tf_dataset = tf_dataset.map(lambda *args: collate_fn(args))

    return tf_dataset
|
@@ -60,96 +60,103 @@ class ActivationHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
|
|
60
60
|
Returns:
|
61
61
|
List[np.ndarray]: Scores based on the Hessian-approximation for the requested nodes.
|
62
62
|
"""
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
63
|
+
model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
|
64
|
+
|
65
|
+
if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
|
66
|
+
Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
|
67
|
+
"This operation is not supported. "
|
68
|
+
"Remove the output node from the set of node targets in the Hessian request.")
|
69
|
+
|
70
|
+
grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
|
71
|
+
|
72
|
+
# Building a model to run Hessian approximation on
|
73
|
+
model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
|
74
|
+
|
75
|
+
# Record operations for automatic differentiation
|
76
|
+
with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
|
77
|
+
g.watch(self.input_images)
|
78
|
+
|
79
|
+
if len(self.input_images) > 1:
|
80
|
+
outputs = model(self.input_images)
|
81
|
+
else:
|
82
|
+
outputs = model(*self.input_images)
|
83
|
+
|
84
|
+
if len(outputs) != len(grad_model_outputs): # pragma: no cover
|
85
|
+
Logger.critical(
|
86
|
+
f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
|
87
|
+
f"outputs, but got {len(outputs)} output tensors.")
|
88
|
+
|
89
|
+
# Extracting the intermediate activation tensors and the model real output.
|
90
|
+
# Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
|
91
|
+
num_target_nodes = len(self.hessian_request.target_nodes)
|
92
|
+
# Extract activation tensors of nodes for which we want to compute Hessian
|
93
|
+
target_activation_tensors = outputs[:num_target_nodes]
|
94
|
+
# Extract the model outputs
|
95
|
+
output_tensors = outputs[num_target_nodes:]
|
96
|
+
|
97
|
+
# Unfold and concatenate all outputs to form a single tensor
|
98
|
+
output = self._concat_tensors(output_tensors)
|
99
|
+
|
100
|
+
# List to store the Hessian-approximation scores for each interest point
|
101
|
+
ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
|
102
|
+
for _ in range(len(target_activation_tensors))]
|
103
|
+
|
104
|
+
# Loop through each interest point activation tensor
|
105
|
+
prev_mean_results = None
|
106
|
+
for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
|
107
|
+
# Generate random tensor of 1s and -1s
|
108
|
+
v = self._generate_random_vectors_batch(output.shape)
|
109
|
+
f_v = tf.reduce_sum(v * output)
|
110
|
+
for i, ipt in enumerate(target_activation_tensors): # Per Interest point activation tensor
|
111
|
+
interest_point_scores = [] # List to store scores for each interest point
|
112
|
+
with g.stop_recording():
|
113
|
+
# Computing the approximation by getting the gradient of (output * v)
|
114
|
+
hess_v = g.gradient(f_v, ipt)
|
115
|
+
|
116
|
+
if hess_v is None:
|
117
|
+
# In case we have an output node, which is an interest point, but it is not
|
118
|
+
# differentiable, we consider its Hessian to be the initial value 0.
|
119
|
+
continue # pragma: no cover
|
120
|
+
|
121
|
+
if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
|
122
122
|
# Mean over all dims but the batch (CXHXW for conv)
|
123
123
|
hessian_approx = tf.reduce_sum(hess_v ** 2.0,
|
124
124
|
axis=tuple(d for d in range(1, len(hess_v.shape))))
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
125
|
+
elif self.hessian_request.granularity == HessianScoresGranularity.PER_ELEMENT:
|
126
|
+
hessian_approx = hess_v ** 2
|
127
|
+
elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
|
128
|
+
axes_to_sum = tuple(d for d in range(1, len(hess_v.shape)-1))
|
129
|
+
hessian_approx = tf.reduce_sum(hess_v ** 2.0, axis=axes_to_sum)
|
130
|
+
|
131
|
+
else: # pragma: no cover
|
132
|
+
Logger.critical(f"{self.hessian_request.granularity} "
|
133
|
+
f"is not supported for Keras activation hessian\'s approximation scores calculator.")
|
134
|
+
|
135
|
+
# Free gradients
|
136
|
+
del hess_v
|
137
|
+
|
138
|
+
# Update node Hessian approximation mean over random iterations
|
139
|
+
ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
|
140
|
+
|
141
|
+
# If the change to the mean approximation is insignificant (to all outputs)
|
142
|
+
# we stop the calculation.
|
143
|
+
if j > MIN_HESSIAN_ITER and prev_mean_results is not None:
|
144
|
+
new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
|
145
|
+
relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
|
146
|
+
(tf.abs(new_mean_res) + 1e-6))
|
147
|
+
max_delta = tf.reduce_max(relative_delta_per_node)
|
148
|
+
if max_delta < HESSIAN_COMP_TOLERANCE:
|
149
|
+
break
|
150
|
+
|
151
|
+
if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
|
142
152
|
prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
|
143
153
|
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
154
|
+
# Convert results to list of numpy arrays
|
155
|
+
hessian_results = [h.numpy() for h in ipts_hessian_approximations]
|
156
|
+
# Extend the Hessian tensors shape to align with expected return type
|
157
|
+
# TODO: currently, only per-tensor Hessian is available for activation.
|
158
|
+
# Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
|
159
|
+
hessian_results = [h[..., np.newaxis] for h in hessian_results]
|
150
160
|
|
151
|
-
|
161
|
+
return hessian_results
|
152
162
|
|
153
|
-
else: # pragma: no cover
|
154
|
-
Logger.critical(f"{self.hessian_request.granularity} "
|
155
|
-
f"is not supported for Keras activation hessian\'s approximation scores calculator.")
|
@@ -12,6 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
+
from tensorflow import TensorShape
|
15
16
|
|
16
17
|
from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
|
17
18
|
|
@@ -77,3 +78,19 @@ class HessianScoresCalculatorKeras(HessianScoresCalculator):
|
|
77
78
|
"Unable to concatenate tensors for gradient calculation due to mismatched shapes along the first axis.") # pragma: no cover
|
78
79
|
|
79
80
|
return tf.concat(_r_tensors, axis=1)
|
81
|
+
|
82
|
+
def _generate_random_vectors_batch(self, shape: TensorShape) -> tf.Tensor:
    """
    Draw a Rademacher random tensor for Hutchinson estimation.

    Args:
        shape: shape of the tensor to generate.

    Returns:
        float32 tensor of the requested shape whose entries are -1.0 or +1.0.
    """
    # Integers drawn uniformly from {0, 1} (maxval is exclusive for tf.random.uniform).
    bits = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)
    # Affine map 0 -> -1, 1 -> +1, then convert to float.
    return tf.cast(2 * bits - 1, tf.float32)
|
96
|
+
|
@@ -89,8 +89,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
|
|
89
89
|
prev_mean_results = None
|
90
90
|
tensors_original_shape = []
|
91
91
|
for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
|
92
|
-
|
93
|
-
v = tf.random.normal(shape=output.shape)
|
92
|
+
v = self._generate_random_vectors_batch(output.shape)
|
94
93
|
f_v = tf.reduce_sum(v * output)
|
95
94
|
|
96
95
|
for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
|
@@ -438,17 +438,11 @@ class KerasImplementation(FrameworkImplementation):
|
|
438
438
|
node: Node to indicate whether it needs to be part of the interest points set.
|
439
439
|
Returns: True if the node should be considered an interest point, False otherwise.
|
440
440
|
"""
|
441
|
-
|
442
|
-
if node.is_match_type(Activation):
|
443
|
-
node_type_name = node.framework_attr[keras_constants.ACTIVATION]
|
444
|
-
if node_type_name in [keras_constants.SOFTMAX, keras_constants.SIGMOID]:
|
445
|
-
return True
|
446
|
-
elif any([node.is_match_type(_type) for _type in [tf.nn.softmax, tf.keras.layers.Softmax, tf.nn.sigmoid, Conv2D,
|
447
|
-
DepthwiseConv2D, Conv2DTranspose, Dense, Concatenate, tf.concat,
|
448
|
-
Add, tf.add]]):
|
441
|
+
if self.is_softmax(node) or self.is_sigmoid(node):
|
449
442
|
return True
|
450
443
|
|
451
|
-
return
|
444
|
+
return any([node.is_match_type(_type) for _type in [Conv2D, DepthwiseConv2D, Conv2DTranspose, Dense,
|
445
|
+
Concatenate, tf.concat, Add, tf.add]])
|
452
446
|
|
453
447
|
def get_mp_node_distance_fn(self, n: BaseNode,
|
454
448
|
compute_distance_fn: Callable = None,
|
@@ -466,32 +460,34 @@ class KerasImplementation(FrameworkImplementation):
|
|
466
460
|
Returns: A distance function between two tensors and a axis on which the distance is computed (if exists).
|
467
461
|
"""
|
468
462
|
|
469
|
-
axis = n.framework_attr.get(keras_constants.AXIS)
|
470
|
-
if not isinstance(n, FunctionalNode) else n.op_call_kwargs.get(keras_constants.AXIS)
|
471
|
-
|
472
|
-
layer_class = n.layer_class
|
473
|
-
framework_attrs = n.framework_attr
|
463
|
+
axis = n.op_call_kwargs.get(keras_constants.AXIS) if isinstance(n, FunctionalNode) else n.framework_attr.get(keras_constants.AXIS)
|
474
464
|
|
475
465
|
if compute_distance_fn is not None:
|
476
466
|
return compute_distance_fn, axis
|
477
467
|
|
478
|
-
if
|
479
|
-
|
480
|
-
if node_type_name == SOFTMAX and axis is not None:
|
481
|
-
return compute_kl_divergence, axis
|
482
|
-
elif node_type_name == SIGMOID:
|
483
|
-
return compute_cs, axis
|
484
|
-
elif axis is not None and (layer_class == tf.nn.softmax or layer_class == tf.keras.layers.Softmax
|
485
|
-
or (layer_class == TFOpLambda and
|
486
|
-
SOFTMAX in framework_attrs[keras_constants.FUNCTION])):
|
468
|
+
# TODO should we really return mse if axis is None? Error? Fill default?
|
469
|
+
if self.is_softmax(n) and axis is not None:
|
487
470
|
return compute_kl_divergence, axis
|
488
|
-
|
489
|
-
|
490
|
-
return compute_cs, axis
|
491
|
-
elif layer_class == Dense:
|
471
|
+
|
472
|
+
if self.is_sigmoid(n) or n.layer_class == Dense:
|
492
473
|
return compute_cs, axis
|
474
|
+
|
493
475
|
return partial(compute_mse, norm=norm_mse), axis
|
494
476
|
|
477
|
+
@staticmethod
def is_sigmoid(node: BaseNode):
    """
    Check whether a node applies a sigmoid.

    Matches one of:
    - an Activation layer whose activation attribute equals SIGMOID,
    - a tf.nn.sigmoid call,
    - a TFOpLambda whose function attribute contains SIGMOID.

    NOTE(review): the TFOpLambda check is a substring test, so any function name that merely contains the
    SIGMOID identifier also matches (e.g. a log-sigmoid op, if SIGMOID is 'sigmoid'). Confirm this is intended.

    Returns:
        True if the node is a sigmoid, False otherwise.
    """
    layer_cls = node.layer_class
    if layer_cls == Activation and node.framework_attr[ACTIVATION] == SIGMOID:
        return True
    if layer_cls == tf.nn.sigmoid:
        return True
    if layer_cls == TFOpLambda:
        return SIGMOID in node.framework_attr[keras_constants.FUNCTION]
    return False
|
483
|
+
|
484
|
+
@staticmethod
def is_softmax(node: BaseNode):
    """
    Check whether a node applies a softmax.

    Matches one of:
    - an Activation layer whose activation attribute equals SOFTMAX,
    - a tf.nn.softmax call or a tf.keras.layers.Softmax layer,
    - a TFOpLambda whose function attribute contains SOFTMAX.

    NOTE(review): the TFOpLambda check is a substring test, so any function name that merely contains the
    SOFTMAX identifier also matches (e.g. a log-softmax op, if SOFTMAX is 'softmax'). Confirm this is intended.

    Returns:
        True if the node is a softmax, False otherwise.
    """
    layer_cls = node.layer_class
    if layer_cls == Activation and node.framework_attr[ACTIVATION] == SOFTMAX:
        return True
    if layer_cls in (tf.nn.softmax, tf.keras.layers.Softmax):
        return True
    if layer_cls == TFOpLambda:
        return SOFTMAX in node.framework_attr[keras_constants.FUNCTION]
    return False
|
490
|
+
|
495
491
|
def get_hessian_scores_calculator(self,
|
496
492
|
graph: Graph,
|
497
493
|
input_images: List[Any],
|
@@ -427,10 +427,8 @@ class PytorchImplementation(FrameworkImplementation):
|
|
427
427
|
Returns: True if the node should be considered an interest point, False otherwise.
|
428
428
|
"""
|
429
429
|
|
430
|
-
|
431
|
-
|
432
|
-
return True
|
433
|
-
return False
|
430
|
+
return any(node.is_match_type(_type) for _type in [Conv2d, Linear, ConvTranspose2d, Sigmoid, sigmoid, Softmax,
|
431
|
+
softmax, operator.add, add, cat, operator.concat])
|
434
432
|
|
435
433
|
def get_mp_node_distance_fn(self, n: BaseNode,
|
436
434
|
compute_distance_fn: Callable = None,
|
@@ -27,7 +27,11 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
|
|
27
27
|
from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
|
28
28
|
from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
|
29
29
|
from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
|
30
|
+
from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
|
31
|
+
get_gradual_activation_quantizer_wrapper_factory
|
32
|
+
from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
|
30
33
|
from model_compression_toolkit.logger import Logger
|
34
|
+
from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
|
31
35
|
|
32
36
|
|
33
37
|
class GPTQTrainer(ABC):
|
@@ -64,6 +68,14 @@ class GPTQTrainer(ABC):
|
|
64
68
|
self.fw_impl = fw_impl
|
65
69
|
self.fw_info = fw_info
|
66
70
|
self.representative_data_gen_fn = representative_data_gen_fn
|
71
|
+
|
72
|
+
def _get_total_grad_steps():
|
73
|
+
return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs
|
74
|
+
|
75
|
+
self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(gptq_config,
|
76
|
+
_get_total_grad_steps,
|
77
|
+
self.fw_linear_annealing_scheduler)
|
78
|
+
|
67
79
|
# ----------------------------------------------
|
68
80
|
# Build two models and create compare nodes
|
69
81
|
# ----------------------------------------------
|
@@ -81,6 +93,52 @@ class GPTQTrainer(ABC):
|
|
81
93
|
f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover
|
82
94
|
self.hessian_service = hessian_info_service
|
83
95
|
|
96
|
+
self.reg_func = get_regularization(self.gptq_config,
|
97
|
+
_get_total_grad_steps,
|
98
|
+
self.fw_soft_quantizer_regularization,
|
99
|
+
self.fw_linear_annealing_scheduler)
|
100
|
+
self.loss_list = []
|
101
|
+
self.input_scale = 1
|
102
|
+
if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
|
103
|
+
Logger.critical("Input scale mismatch between float and GPTQ networks. "
|
104
|
+
"Ensure both networks have matching input scales.") # pragma: no cover
|
105
|
+
else:
|
106
|
+
self.input_scale = self.gptq_user_info.input_scale
|
107
|
+
|
108
|
+
trainable_weights, trainable_bias, trainable_threshold = self.fw_get_gptq_trainable_parameters_fn(
|
109
|
+
self.fxp_model,
|
110
|
+
add_bias=self.gptq_config.train_bias)
|
111
|
+
self.flp_weights_list, self.fxp_weights_list = self.fw_get_weights_for_loss_fn(self.fxp_model)
|
112
|
+
|
113
|
+
if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
|
114
|
+
self.fxp_weights_list)):
|
115
|
+
Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
|
116
|
+
"and the number of float and quantized weights for loss calculation. "
|
117
|
+
"Ensure all these elements align to proceed with GPTQ training.")
|
118
|
+
|
119
|
+
# In Keras we need to flatten the weights first before attaching the optimizer
|
120
|
+
if len(trainable_weights) > 0 and isinstance(trainable_weights[0], (list, tuple)):
|
121
|
+
trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
|
122
|
+
if len(trainable_bias) > 0 and isinstance(trainable_bias[0], (list, tuple)):
|
123
|
+
trainable_bias = [w for layer_weights in trainable_bias for w in layer_weights]
|
124
|
+
|
125
|
+
self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
|
126
|
+
trainable_bias,
|
127
|
+
trainable_threshold)
|
128
|
+
hessian_cfg = self.gptq_config.hessian_weights_config
|
129
|
+
|
130
|
+
self.has_params_to_train = np.sum(
|
131
|
+
[len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
|
132
|
+
self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
|
133
|
+
|
134
|
+
if self.use_sample_layer_attention:
|
135
|
+
# normalization is currently not supported, make sure the config reflects it.
|
136
|
+
if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
|
137
|
+
raise NotImplementedError()
|
138
|
+
self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen_fn)
|
139
|
+
else:
|
140
|
+
self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen_fn)
|
141
|
+
|
84
142
|
def get_optimizer_with_param(self,
|
85
143
|
flattened_trainable_weights: List[Any],
|
86
144
|
flattened_bias_weights: List[Any],
|
@@ -13,9 +13,8 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
|
16
|
-
from typing import Any, Tuple, List
|
17
|
-
|
18
16
|
import tensorflow as tf
|
17
|
+
from typing import List, Tuple
|
19
18
|
|
20
19
|
|
21
20
|
def mse_loss(y: tf.Tensor, x: tf.Tensor, normalized: bool = True) -> tf.Tensor:
|
@@ -67,6 +66,40 @@ def multiple_tensors_mse_loss(y_list: List[tf.Tensor],
|
|
67
66
|
else:
|
68
67
|
return tf.reduce_mean(tf.stack(loss_values_list))
|
69
68
|
|
69
|
+
def sample_layer_attention_loss(y_list: List[tf.Tensor],
                                x_list: List[tf.Tensor],
                                fxp_w_list,
                                flp_w_list,
                                act_bn_mean,
                                act_bn_std,
                                loss_weights: Tuple[tf.Tensor]) -> tf.Tensor:
    """
    Compute Sample Layer Attention loss between two lists of tensors using TensorFlow.

    Args:
        y_list: First list of tensors (one per compared layer).
        x_list: Second list of tensors, aligned with ``y_list``.
        fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
        loss_weights: layer-sample attention scores: a tuple with one element per layer, where each element is a
            tf.Tensor vector whose length is the number of samples.

    Returns:
        Sample Layer Attention loss (a scalar).
    """
    loss = 0
    layers_mean_w = []
    # Column i of the stacked tensor holds the per-sample weights of layer i.
    loss_weights = tf.stack(loss_weights, axis=1)

    for i, (y, x) in enumerate(zip(y_list, x_list)):
        # NOTE(review): axis=1 is the channel axis only for channels-first tensors (and for 2-D Dense outputs).
        # For channels-last activations it is the first spatial axis, which changes the per-layer scale.
        # Confirm this is intended.
        norm = tf.reduce_sum(tf.square(y - x), axis=1)
        if len(norm.shape) > 1:
            # Reduce the remaining non-batch dims to one value per sample. Use the dynamic batch size so the reshape
            # also works when the static batch dimension is unknown (None), e.g. when traced inside tf.function.
            norm = tf.reduce_mean(tf.reshape(norm, [tf.shape(norm)[0], -1]), axis=1)
        w = loss_weights[:, i]
        loss += tf.reduce_mean(w * norm)
        layers_mean_w.append(tf.reduce_mean(w))

    # Normalize by the largest per-layer mean weight.
    loss = loss / tf.reduce_max(tf.stack(layers_mean_w))
    return loss
|
102
|
+
|
70
103
|
|
71
104
|
def mse_loss_per_tensor(y: tf.Tensor,
|
72
105
|
x: tf.Tensor,
|