mct-nightly 2.2.0.20241204.524__py3-none-any.whl → 2.2.0.20241206.524__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mct-nightly
- Version: 2.2.0.20241204.524
+ Version: 2.2.0.20241206.524
  Summary: A Model Compression Toolkit for neural networks
  Home-page: UNKNOWN
  License: UNKNOWN
@@ -1,4 +1,4 @@
- model_compression_toolkit/__init__.py,sha256=CwUJhq62PjrcRJgkwwmu5dArRV2bz7lgnxc2ebnm840,1573
+ model_compression_toolkit/__init__.py,sha256=KhP8R07jwQig7PMnV7NExSRFSjG_rAbMcGhuL8koQWc,1573
  model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -155,7 +155,7 @@ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
  model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
- model_compression_toolkit/core/keras/data_util.py,sha256=JdomIJZfep0QYPtx2jlg0xJ40cd9S_I7BakaWQi0wKw,2681
+ model_compression_toolkit/core/keras/data_util.py,sha256=sTEuHUrT8S3CpeAEG0XDlYA0bWZKISGPilObPlO0TA8,6833
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
  model_compression_toolkit/core/keras/keras_implementation.py,sha256=HwbIR7x4t-TBNbWHVvVNFk8z-KFt6zM0LWAUXQuNZrk,31753
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
@@ -193,9 +193,9 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_s
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
  model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
- model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py,sha256=p0eM-EO5ltXYjSkd7B3h9BWBcuRZvjxEcA8WaNvdyqc,8901
- model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=Cep-bQEwLyqLYfLxM0ByOQd_oAIT-uXjr3dFUd8T9CY,3954
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=970C-8J4HtUalNWvZAKlWFZVfw5r6SBdt5RQU_mZ7M0,12261
+ model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py,sha256=qGEyOzC1_NIcnBmvvjA-GT7o9-PWo0Ko66vcEyLixhw,9180
+ model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=1o7X9GXSfpEmuB5ee2AaBQ2sN2xzX4-smbrq_0qOGRU,4454
+ model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=Rl6NNGkHMV0ioEM5bbM4XX7yHDqG6mMp4ifN2VQBDxE,12168
  model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
  model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
  model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
@@ -355,21 +355,21 @@ model_compression_toolkit/gptq/common/gptq_config.py,sha256=QwSEZZlC6OpnpoBQoAFf
  model_compression_toolkit/gptq/common/gptq_constants.py,sha256=8HB0yiX75zZ1IKgQUPWpFCM5sS8HAqslws5XrOhxJQ0,750
  model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
  model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
- model_compression_toolkit/gptq/common/gptq_training.py,sha256=EnG-17U6kGDgTeMkOJQmRoMs0KUldROss683_Bo5oHQ,13249
+ model_compression_toolkit/gptq/common/gptq_training.py,sha256=vvrQH1MIW3w90yt9VKBW6jTMXkBrDY82JmCDwT8Kve8,17002
  model_compression_toolkit/gptq/common/gradual_activation_quantization.py,sha256=EgpzMs_aDoB0wQiTagqvcxCTfrgNUuCfdXEXmfNiyb0,3780
  model_compression_toolkit/gptq/common/regularization_factory.py,sha256=hyunpXepVeHyoAFJw6zNLK-3ZHBmiut3lmNisJN_L3E,2514
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
- model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=yBiAod9hbzh2bp4xhVO5szmtCHm6bLUa7-kjUVVwo40,20845
- model_compression_toolkit/gptq/keras/graph_info.py,sha256=MKIfrRTRH3zCuxCR1g9ZVIFyuSSr0e0sDybqh4LDM7E,4672
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=e3O835Ol5ML0XuqNsCmoTbnnfs-gEgrSGT1ijUZLX7Q,17102
+ model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=0WGiP7Gs4xX3FBs1PNaZ7w3hWRigwQXqYjBrs_-x32o,23241
+ model_compression_toolkit/gptq/keras/graph_info.py,sha256=zwoeHX67nJJ5-zYLjzvMXS9TLsy9BsizARbZiDVjVSA,4473
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=PO-tNoCoWQpXgefVxqxBfAQ29kGe_DFBgiOQ2DLYato,18005
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
  model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=rst-u5EB9Xss4ndKqi297WvZ-9RVee2TAUVFelPVKhU,4663
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=REO-pIXpT4ZuJzhizvQjz6vn7Vxnq7k0KvikuQ4FDkE,2769
+ model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=UTvEL5hN2cEsMwiGBDbpcE0kQr32VFKwlJBWlDg8HNA,3271
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=BBSDWLmeywjSM5N6oJkMgcuo7zrXTesB4zLwRGG8QB0,12159
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=pyhlVpoauHM-zuixHsIGPHFgQoXppL8TlDFCjPE2RuY,10377
  model_compression_toolkit/gptq/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -377,7 +377,7 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phIqeu5-_7_5VBT4DT-FCw7Do,3892
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=iuZJcoG2w-7qjWGntXWTdU2XUuMPy5IwzZbiolThuI4,22145
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=WtehnyiYXdUXf8-uNpV0mdsalF7YF7eKnL7tcFrzZoE,19549
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
  model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=hZFU_ZY-LYcpRZyzzX7NsJievkIYKGdkgBzEoB4rsRQ,16020
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
@@ -559,8 +559,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.2.0.20241204.524.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- mct_nightly-2.2.0.20241204.524.dist-info/METADATA,sha256=O3ETKzNDjZGmSvp_WVmqIJz-jyk93WLG676QjyRsISs,26446
- mct_nightly-2.2.0.20241204.524.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- mct_nightly-2.2.0.20241204.524.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.2.0.20241204.524.dist-info/RECORD,,
+ mct_nightly-2.2.0.20241206.524.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ mct_nightly-2.2.0.20241206.524.dist-info/METADATA,sha256=Q-MSMJXd4He0d0RJ_jhEABCs2FgxB6vZIGjv24boOnw,26446
+ mct_nightly-2.2.0.20241206.524.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ mct_nightly-2.2.0.20241206.524.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.2.0.20241206.524.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model

- __version__ = "2.2.0.20241204.000524"
+ __version__ = "2.2.0.20241206.000524"
@@ -18,6 +18,27 @@ import tensorflow as tf

  from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor

+ import tensorflow as tf
+ from typing import Callable, Generator, Sequence, Any
+
+
+ def get_tensor_spec(item, ignore_batch_dim=False):
+     """
+     Get the TensorFlow TensorSpec for an item, optionally ignoring the first dimension.
+
+     Args:
+         item: The input item, which could be a tensor, tuple, or list.
+         ignore_batch_dim (bool): Whether to ignore the first dimension of the tensor shape.
+
+     Returns:
+         TensorSpec or a tuple of TensorSpecs.
+     """
+     if isinstance(item, (tuple, list)):
+         return tuple(get_tensor_spec(sub_item, ignore_batch_dim) for sub_item in item)
+
+     shape = item.shape[1:] if ignore_batch_dim else item.shape
+     return tf.TensorSpec(shape=shape, dtype=item.dtype)
+

  def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
      """
@@ -29,39 +50,151 @@ def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
      Returns:
          A factory for a flattened data generator.
      """
+
      def gen():
          for inputs_batch in data_gen_fn():
              for sample in zip(*inputs_batch):
-                 yield to_tf_tensor(sample)
-     return gen
+                 yield tuple([tf.convert_to_tensor(s) for s in sample])

+     return gen

- # TODO in tf dataset and dataloader are combined within tf.data.Dataset. For advanced use cases such as gptq sla we
- # need to separate dataset from dataloader similarly to torch data_util.
  class TFDatasetFromGenerator:
-     def __init__(self, data_gen, batch_size):
-         inputs = next(data_gen())
-         if not isinstance(inputs, list):
-             raise TypeError(f'Representative data generator is expected to generate a list of tensors, '
-                             f'got {type(inputs)}') # pragma: no cover
+     """
+     TensorFlow dataset from a data generator function, batched to a specified size.
+     """

+     def __init__(self, data_gen_fn: Callable[[], Generator]):
+         """
+         Args:
+             data_gen_fn: a factory function for data generator that yields lists of tensors.
+         """
+         inputs = next(data_gen_fn())
+         if not isinstance(inputs, list):
+             raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
          self.orig_batch_size = inputs[0].shape[0]
-
-         output_signature = tuple([tf.TensorSpec(shape=t.shape[1:], dtype=t.dtype) for t in inputs])
-         dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen), output_signature=output_signature)
-         self.dataset = dataset.batch(batch_size)
          self._size = None

+         # TFDatasetFromGenerator flattens the dataset, thus we ignore the batch dimension
+         output_signature = get_tensor_spec(inputs, ignore_batch_dim=True)
+         self.dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)
+
+
      def __iter__(self):
          return iter(self.dataset)

      def __len__(self):
          """ Returns the number of batches. """
          if self._size is None:
-             self._num_batches = sum(1 for _ in self)
-         return self._num_batches
+             self._size = sum(1 for _ in self.dataset)
+         return self._size
+
+
+
+ class FixedTFDataset:
+     """
+     Fixed dataset containing samples from a generator, stored in memory.
+     """
+
+     def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
+         """
+         Args:
+             data_gen_fn: data generator function.
+             n_samples: number of samples to store in the dataset. If None, uses all samples in one pass.
+         """
+         inputs = next(data_gen_fn())
+         if not isinstance(inputs, list):
+             raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
+         self.orig_batch_size = inputs[0].shape[0]
+
+         samples = []
+         for batch in data_gen_fn():
+             samples.extend(zip(*[tf.convert_to_tensor(t) for t in batch]))
+             if n_samples is not None and len(samples) >= n_samples:
+                 samples = samples[:n_samples]
+                 break
+
+         if n_samples and len(samples) < n_samples:
+             raise ValueError(f'Not enough samples to create a dataset with {n_samples} samples')
+         self.samples = samples
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         return self.samples[index]
+
+
+ class FixedSampleInfoDataset:
+     """
+     Dataset for samples with additional info, each element is a tuple of (sample, sample_info).
+     """
+
+     def __init__(self, samples: Sequence, sample_info: Sequence):
+         if not all(len(info) == len(samples) for info in sample_info):
+             raise ValueError('Sample and additional info lengths must match')
+         self.samples = samples
+         self.sample_info = sample_info
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         return self.samples[index], tuple([info[index] for info in self.sample_info])
+
+
+ class IterableSampleWithConstInfoDataset:
+     """
+     Augments each sample in an iterable dataset with constant additional information.
+     """
+
+     def __init__(self, samples_dataset: tf.data.Dataset, *info: Any):
+         self.samples_dataset = samples_dataset
+         self.info = info
+
+     def __iter__(self):
+         for sample in self.samples_dataset:
+             yield (sample, *self.info)
+
+
+ def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
+     """Create a DataLoader based on samples yielded by data_gen."""
+     ds = TFDatasetFromGenerator(data_gen_fn)
+     return create_tf_dataloader(dataset=ds, batch_size=batch_size)
+
+
+ def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
+     """
+     Creates a tf.data.Dataset with specified loading options.
+
+     Args:
+         dataset: The dataset container (e.g., FixedDatasetFromGenerator or FixedSampleInfoDataset).
+         batch_size: Number of samples per batch.
+         shuffle: Whether to shuffle the dataset.
+         collate_fn: A function to apply to each batch (e.g., add extra outputs like regularization weights).
+
+     Returns:
+         tf.data.Dataset: Configured for batching, shuffling, and custom transformations.
+     """
+     def generator():
+         for item in dataset:
+             yield item
+
+     dummy_input_tensors = next(generator())
+
+     output_signature = get_tensor_spec(dummy_input_tensors)
+
+     tf_dataset = tf.data.Dataset.from_generator(
+         generator,
+         output_signature=output_signature
+     )
+
+     if shuffle:
+         tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
+
+     tf_dataset = tf_dataset.batch(batch_size)

+     # Apply collate function if provided
+     if collate_fn:
+         tf_dataset = tf_dataset.map(lambda *args: collate_fn(args))

- def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size) -> TFDatasetFromGenerator:
-     """ Create DataLoader based on samples yielded by data_gen. """
-     return TFDatasetFromGenerator(data_gen_fn, batch_size)
+     return tf_dataset
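For illustration only (not part of the diff): a minimal usage sketch of the new data utilities, assuming a representative data generator that yields lists of batched arrays:

    import numpy as np

    def representative_data_gen():
        for _ in range(4):
            yield [np.random.randn(8, 32, 32, 3).astype('float32')]

    # Flattens the generator into individual samples and re-batches them to size 16.
    dataloader = data_gen_to_dataloader(representative_data_gen, batch_size=16)
    for batch in dataloader:
        pass  # each element is a tuple of tf.Tensors with a leading batch dimension of 16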
@@ -60,96 +60,103 @@ class ActivationHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
          Returns:
              List[np.ndarray]: Scores based on the Hessian-approximation for the requested nodes.
          """
-         if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
-             model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
-
-             if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
-                 Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
-                                 "This operation is not supported. "
-                                 "Remove the output node from the set of node targets in the Hessian request.")
-
-             grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
-
-             # Building a model to run Hessian approximation on
-             model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
-
-             # Record operations for automatic differentiation
-             with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
-                 g.watch(self.input_images)
-
-                 if len(self.input_images) > 1:
-                     outputs = model(self.input_images)
-                 else:
-                     outputs = model(*self.input_images)
-
-                 if len(outputs) != len(grad_model_outputs): # pragma: no cover
-                     Logger.critical(
-                         f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
-                         f"outputs, but got {len(outputs)} output tensors.")
-
-                 # Extracting the intermediate activation tensors and the model real output.
-                 # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
-                 num_target_nodes = len(self.hessian_request.target_nodes)
-                 # Extract activation tensors of nodes for which we want to compute Hessian
-                 target_activation_tensors = outputs[:num_target_nodes]
-                 # Extract the model outputs
-                 output_tensors = outputs[num_target_nodes:]
-
-                 # Unfold and concatenate all outputs to form a single tensor
-                 output = self._concat_tensors(output_tensors)
-
-                 # List to store the Hessian-approximation scores for each interest point
-                 ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
-                                                for _ in range(len(target_activation_tensors))]
-
-                 # Loop through each interest point activation tensor
-                 prev_mean_results = None
-                 for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
-                     # Getting a random vector with normal distribution
-                     v = tf.random.normal(shape=output.shape, dtype=output.dtype)
-                     f_v = tf.reduce_sum(v * output)
-                     for i, ipt in enumerate(target_activation_tensors): # Per Interest point activation tensor
-                         interest_point_scores = [] # List to store scores for each interest point
-                         with g.stop_recording():
-                             # Computing the approximation by getting the gradient of (output * v)
-                             hess_v = g.gradient(f_v, ipt)
-
-                             if hess_v is None:
-                                 # In case we have an output node, which is an interest point, but it is not
-                                 # differentiable, we consider its Hessian to be the initial value 0.
-                                 continue # pragma: no cover
-
+         model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
+
+         if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
+             Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
+                             "This operation is not supported. "
+                             "Remove the output node from the set of node targets in the Hessian request.")
+
+         grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
+
+         # Building a model to run Hessian approximation on
+         model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
+
+         # Record operations for automatic differentiation
+         with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
+             g.watch(self.input_images)
+
+             if len(self.input_images) > 1:
+                 outputs = model(self.input_images)
+             else:
+                 outputs = model(*self.input_images)
+
+             if len(outputs) != len(grad_model_outputs): # pragma: no cover
+                 Logger.critical(
+                     f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
+                     f"outputs, but got {len(outputs)} output tensors.")
+
+             # Extracting the intermediate activation tensors and the model real output.
+             # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
+             num_target_nodes = len(self.hessian_request.target_nodes)
+             # Extract activation tensors of nodes for which we want to compute Hessian
+             target_activation_tensors = outputs[:num_target_nodes]
+             # Extract the model outputs
+             output_tensors = outputs[num_target_nodes:]
+
+             # Unfold and concatenate all outputs to form a single tensor
+             output = self._concat_tensors(output_tensors)
+
+             # List to store the Hessian-approximation scores for each interest point
+             ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
+                                            for _ in range(len(target_activation_tensors))]
+
+             # Loop through each interest point activation tensor
+             prev_mean_results = None
+             for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
+                 # Generate random tensor of 1s and -1s
+                 v = self._generate_random_vectors_batch(output.shape)
+                 f_v = tf.reduce_sum(v * output)
+                 for i, ipt in enumerate(target_activation_tensors): # Per Interest point activation tensor
+                     interest_point_scores = [] # List to store scores for each interest point
+                     with g.stop_recording():
+                         # Computing the approximation by getting the gradient of (output * v)
+                         hess_v = g.gradient(f_v, ipt)
+
+                         if hess_v is None:
+                             # In case we have an output node, which is an interest point, but it is not
+                             # differentiable, we consider its Hessian to be the initial value 0.
+                             continue # pragma: no cover
+
+                         if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
                              # Mean over all dims but the batch (CXHXW for conv)
                              hessian_approx = tf.reduce_sum(hess_v ** 2.0,
                                                             axis=tuple(d for d in range(1, len(hess_v.shape))))
-
-                             # Free gradients
-                             del hess_v
-
-                             # Update node Hessian approximation mean over random iterations
-                             ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
-
-                         # If the change to the mean approximation is insignificant (to all outputs)
-                         # we stop the calculation.
-                         if j > MIN_HESSIAN_ITER:
-                             if prev_mean_results is not None:
-                                 new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
-                                 relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
-                                                            (tf.abs(new_mean_res) + 1e-6))
-                                 max_delta = tf.reduce_max(relative_delta_per_node)
-                                 if max_delta < HESSIAN_COMP_TOLERANCE:
-                                     break
+                         elif self.hessian_request.granularity == HessianScoresGranularity.PER_ELEMENT:
+                             hessian_approx = hess_v ** 2
+                         elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
+                             axes_to_sum = tuple(d for d in range(1, len(hess_v.shape)-1))
+                             hessian_approx = tf.reduce_sum(hess_v ** 2.0, axis=axes_to_sum)
+
+                         else: # pragma: no cover
+                             Logger.critical(f"{self.hessian_request.granularity} "
+                                             f"is not supported for Keras activation hessian\'s approximation scores calculator.")
+
+                         # Free gradients
+                         del hess_v
+
+                         # Update node Hessian approximation mean over random iterations
+                         ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
+
+                     # If the change to the mean approximation is insignificant (to all outputs)
+                     # we stop the calculation.
+                     if j > MIN_HESSIAN_ITER and prev_mean_results is not None:
+                         new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
+                         relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
+                                                    (tf.abs(new_mean_res) + 1e-6))
+                         max_delta = tf.reduce_max(relative_delta_per_node)
+                         if max_delta < HESSIAN_COMP_TOLERANCE:
+                             break
+
+                 if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
                      prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)

-             # Convert results to list of numpy arrays
-             hessian_results = [h.numpy() for h in ipts_hessian_approximations]
-             # Extend the Hessian tensors shape to align with expected return type
-             # TODO: currently, only per-tensor Hessian is available for activation.
-             # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
-             hessian_results = [h[..., np.newaxis] for h in hessian_results]
+             # Convert results to list of numpy arrays
+             hessian_results = [h.numpy() for h in ipts_hessian_approximations]
+             # Extend the Hessian tensors shape to align with expected return type
+             # TODO: currently, only per-tensor Hessian is available for activation.
+             # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
+             hessian_results = [h[..., np.newaxis] for h in hessian_results]

-             return hessian_results
+             return hessian_results

-         else: # pragma: no cover
-             Logger.critical(f"{self.hessian_request.granularity} "
-                             f"is not supported for Keras activation hessian\'s approximation scores calculator.")
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ from tensorflow import TensorShape

  from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator

@@ -77,3 +78,19 @@ class HessianScoresCalculatorKeras(HessianScoresCalculator):
                  "Unable to concatenate tensors for gradient calculation due to mismatched shapes along the first axis.") # pragma: no cover

          return tf.concat(_r_tensors, axis=1)
+
+     def _generate_random_vectors_batch(self, shape: TensorShape) -> tf.Tensor:
+         """
+         Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
+
+         Args:
+             shape: target shape.
+
+         Returns:
+             Random tensor.
+         """
+         v = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)
+         v = tf.where(v == 0, -1, 1)
+         v = tf.cast(v, tf.float32)
+         return v
+
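For context: replacing the Gaussian probe with Rademacher (+1/-1) vectors keeps the Hutchinson estimator unbiased, since E[v v^T] = I when the entries of v are i.i.d. +1/-1. A standalone sketch of the same sampling recipe (illustrative only):

    import tensorflow as tf

    def rademacher(shape):
        # Draw integers in {0, 1}, map them to {-1, +1}, and cast to float.
        v = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)
        return tf.cast(tf.where(v == 0, -1, 1), tf.float32)

    v = rademacher((32, 128))
    # Every entry is exactly +1 or -1, and the sample mean is close to 0.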
@@ -89,8 +89,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
          prev_mean_results = None
          tensors_original_shape = []
          for j in tqdm(range(self.num_iterations_for_approximation)): # Approximation iterations
-             # Getting a random vector with normal distribution and the same shape as the model output
-             v = tf.random.normal(shape=output.shape)
+             v = self._generate_random_vectors_batch(output.shape)
              f_v = tf.reduce_sum(v * output)

              for i, ipt_node in enumerate(self.hessian_request.target_nodes): # Per Interest point weights tensor
@@ -27,7 +27,11 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
  from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
  from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
  from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
+ from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
+     get_gradual_activation_quantizer_wrapper_factory
+ from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
  from model_compression_toolkit.logger import Logger
+ from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps


  class GPTQTrainer(ABC):
@@ -64,6 +68,14 @@ class GPTQTrainer(ABC):
          self.fw_impl = fw_impl
          self.fw_info = fw_info
          self.representative_data_gen_fn = representative_data_gen_fn
+
+         def _get_total_grad_steps():
+             return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs
+
+         self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(gptq_config,
+                                                                                                       _get_total_grad_steps,
+                                                                                                       self.fw_linear_annealing_scheduler)
+
          # ----------------------------------------------
          # Build two models and create compare nodes
          # ----------------------------------------------
@@ -81,6 +93,52 @@ class GPTQTrainer(ABC):
                              f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.") # pragma: no cover
              self.hessian_service = hessian_info_service

+         self.reg_func = get_regularization(self.gptq_config,
+                                            _get_total_grad_steps,
+                                            self.fw_soft_quantizer_regularization,
+                                            self.fw_linear_annealing_scheduler)
+         self.loss_list = []
+         self.input_scale = 1
+         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
+             Logger.critical("Input scale mismatch between float and GPTQ networks. "
+                             "Ensure both networks have matching input scales.") # pragma: no cover
+         else:
+             self.input_scale = self.gptq_user_info.input_scale
+
+         trainable_weights, trainable_bias, trainable_threshold = self.fw_get_gptq_trainable_parameters_fn(
+             self.fxp_model,
+             add_bias=self.gptq_config.train_bias)
+         self.flp_weights_list, self.fxp_weights_list = self.fw_get_weights_for_loss_fn(self.fxp_model)
+
+         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
+                 self.fxp_weights_list)):
+             Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
+                             "and the number of float and quantized weights for loss calculation. "
+                             "Ensure all these elements align to proceed with GPTQ training.")
+
+         # In Keras we need to flatten the weights first before attaching the optimizer
+         if len(trainable_weights) > 0 and isinstance(trainable_weights[0], (list, tuple)):
+             trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
+         if len(trainable_bias) > 0 and isinstance(trainable_bias[0], (list, tuple)):
+             trainable_bias = [w for layer_weights in trainable_bias for w in layer_weights]
+
+         self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
+                                                                   trainable_bias,
+                                                                   trainable_threshold)
+         hessian_cfg = self.gptq_config.hessian_weights_config
+
+         self.has_params_to_train = np.sum(
+             [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
+         self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
+
+         if self.use_sample_layer_attention:
+             # normalization is currently not supported, make sure the config reflects it.
+             if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
+                 raise NotImplementedError()
+             self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen_fn)
+         else:
+             self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen_fn)
+

      def get_optimizer_with_param(self,
                                   flattened_trainable_weights: List[Any],
@@ -13,9 +13,8 @@
  # limitations under the License.
  # ==============================================================================

- from typing import Any, Tuple, List
-
  import tensorflow as tf
+ from typing import List, Tuple


  def mse_loss(y: tf.Tensor, x: tf.Tensor, normalized: bool = True) -> tf.Tensor:
@@ -67,6 +66,40 @@ def multiple_tensors_mse_loss(y_list: List[tf.Tensor],
      else:
          return tf.reduce_mean(tf.stack(loss_values_list))

+ def sample_layer_attention_loss(y_list: List[tf.Tensor],
+                                 x_list: List[tf.Tensor],
+                                 fxp_w_list,
+                                 flp_w_list,
+                                 act_bn_mean,
+                                 act_bn_std,
+                                 loss_weights: Tuple[tf.Tensor]) -> tf.Tensor:
+     """
+     Compute Sample Layer Attention loss between two lists of tensors using TensorFlow.
+
+     Args:
+         y_list: First list of tensors.
+         x_list: Second list of tensors.
+         fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
+         loss_weights: layer-sample attention scores (a tuple of the same length as the number of layers, where each element is a tf.Tensor vector of length of number of samples).
+
+     Returns:
+         Sample Layer Attention loss (a scalar).
+     """
+     loss = 0
+     layers_mean_w = []
+     loss_weights = tf.stack(loss_weights, axis=1)
+
+     for i, (y, x) in enumerate(zip(y_list, x_list)):
+         norm = tf.reduce_sum(tf.square(y - x), axis=1)
+         if len(norm.shape) > 1:
+             norm = tf.reduce_mean(tf.reshape(norm, [norm.shape[0], -1]), axis=1)
+         w = loss_weights[:, i]
+         loss += tf.reduce_mean(w * norm)
+         layers_mean_w.append(tf.reduce_mean(w))
+
+     loss = loss / tf.reduce_max(tf.stack(layers_mean_w))
+     return loss
+

  def mse_loss_per_tensor(y: tf.Tensor,
                          x: tf.Tensor,
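For illustration only (not part of the diff): a toy call of sample_layer_attention_loss, assuming two compare points and a batch of four samples:

    import tensorflow as tf

    y_list = [tf.random.normal((4, 16)), tf.random.normal((4, 32))]   # quantized-model activations
    x_list = [tf.random.normal((4, 16)), tf.random.normal((4, 32))]   # float-model activations
    loss_weights = (tf.ones(4), tf.ones(4))                           # one per-sample weight vector per layer

    loss = sample_layer_attention_loss(y_list, x_list, None, None, None, None, loss_weights)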
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from typing import Callable, List, Tuple, Union
+ from typing import Callable, List, Tuple, Union, Generator

  import tensorflow as tf
  from keras import Model
@@ -20,11 +20,13 @@ from packaging import version
  from tensorflow.keras.layers import Layer
  from tqdm import tqdm

- from model_compression_toolkit.core.common.hessian import HessianInfoService
+ from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity
  # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
  from model_compression_toolkit.core.common.user_info import UserInformation
  from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
- from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
+ from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, \
+     FixedSampleInfoDataset, FixedTFDataset, create_tf_dataloader, TFDatasetFromGenerator, \
+     IterableSampleWithConstInfoDataset
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
  from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
      get_gradual_activation_quantizer_wrapper_factory
@@ -83,13 +85,10 @@ class KerasGPTQTrainer(GPTQTrainer):

          """

-         def _get_total_grad_steps():
-             return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
-
-         # This must be set before the model building (as it is required for activation holder construction),
-         # which occurs in the base constructor.
-         self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-             gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler)
+         self.fw_soft_quantizer_regularization = SoftQuantizerRegularization
+         self.fw_linear_annealing_scheduler = KerasLinearAnnealingScheduler
+         self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+         self.fw_get_weights_for_loss_fn = get_weights_for_loss

          super().__init__(graph_float,
                           graph_quant,
@@ -99,53 +98,106 @@ class KerasGPTQTrainer(GPTQTrainer):
                           representative_data_gen_fn=representative_data_gen,
                           hessian_info_service=hessian_info_service)

-         self.loss_list = []
-         self.input_scale = 1
-
-         trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
-             self.fxp_model,
-             fw_info,
-             add_bias=gptq_config.train_bias)
-
-         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-
-         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                 self.fxp_weights_list)):
-             Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
-                             "and the number of float and quantized weights for loss calculation. "
-                             "Ensure all these elements align to proceed with GPTQ training.")
-
-         flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
-         flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]
-         trainable_quantization_parameters = trainable_threshold
-         self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
-                                                                   flattened_bias_weights,
-                                                                   trainable_quantization_parameters)
-         self.has_params_to_train = np.sum(
-             [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
-
-         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-             Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
-                             "Confirm that the input scales for both models are correctly configured and aligned.") # pragma: no cover
-         else:
-             self.input_scale = self.gptq_user_info.input_scale

-         self.weights_for_average_loss = self._get_compare_points_loss_weights()
+     def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+         """
+         Computes Sample-Layer Attention score and builds a train dataloader in TensorFlow.

-         self.reg_func = get_regularization(self.gptq_config,
-                                            _get_total_grad_steps,
-                                            SoftQuantizerRegularization,
-                                            KerasLinearAnnealingScheduler)
+         Args:
+             data_gen_fn: function for representative dataset generation.

-     def _get_compare_points_loss_weights(self):
-         """ Get compare points weights for the distillation loss. """
-         if self.gptq_config.hessian_weights_config:
-             hess_dataloader = data_gen_to_dataloader(self.representative_data_gen_fn,
-                                                      batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
-             return self.compute_hessian_based_weights(hess_dataloader)
+         Returns:
+             TensorFlow dataset yielding three outputs - samples, weights for the distillation loss,
+             and weights for regularization.
+         """
+         # Create a fixed dataset
+         fixed_dataset = FixedTFDataset(data_gen_fn)
+         orig_batch_size = fixed_dataset.orig_batch_size
+
+         # Prepare a separate loader for computing hessians over the whole dataset
+         hess_data_loader = create_tf_dataloader(
+             fixed_dataset,
+             batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
+             shuffle=False
+         )
+
+         # Prepare request for Hessian computation
+         request = self._build_hessian_request(
+             granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
+             data_loader=hess_data_loader,
+             n_samples=None
+         )
+         layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
+
+         # Compute SLA score defined as max over elements
+         layers_hessians = {
+             layer: tf.convert_to_tensor(tf.reduce_max(hess, axis=tuple(range(1, len(hess.shape))))) for layer, hess in layers_hessians.items()
+         }
+
+         # Stack hessians for comparison points
+         hessians_tensor = tf.stack([layers_hessians[layer.name] for layer in self.compare_points])
+         assert hessians_tensor.shape[0] == len(self.compare_points)
+         loss_weights = list(hessians_tensor.numpy()) # Convert to a list for compatibility
+
+         # Prepare final dataset with samples and loss weights
+         sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
+
+         # Calculate regularization weights as mean across samples
+         reg_weights = tf.reduce_mean(hessians_tensor, axis=1)
+
+         # Define a collate function to add regularization weights to each batch
+         def collate_fn(samples_with_loss_weights):
+             return *samples_with_loss_weights, reg_weights
+
+         # Create final dataset using the new dataloader with collate_fn
+         final_dataset = create_tf_dataloader(
+             dataset=sla_train_dataset,
+             batch_size=orig_batch_size,
+             shuffle=True,
+             collate_fn=collate_fn
+         )
+
+         return final_dataset
+
+     def _prepare_train_dataloader_for_non_sla(self,
+                                               data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+         """
+         Prepares a train dataloader for non-SLA tasks.

+         Args:
+             data_gen_fn: Factory for representative dataset generator.
+
+         Returns:
+             A `tf.data.Dataset` yielding samples with loss weights and regularization weights.
+         """
+         # Step 1: Create a dataset from the generator
+         dataset = TFDatasetFromGenerator(data_gen_fn)
          num_nodes = len(self.compare_points)
-         return np.ones((num_nodes,)) / num_nodes
+
+         # Step 2: Compute loss weights
+         if self.gptq_config.hessian_weights_config:
+             hessian_dataset = create_tf_dataloader(dataset=dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
+             hessian_weights = self.compute_hessian_based_weights(hessian_dataset)
+             loss_weights = tf.convert_to_tensor(hessian_weights, dtype=tf.float32)
+         else:
+             loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes
+
+         # Step 3: Create a dataset with samples and loss weights
+         augmented_dataset = IterableSampleWithConstInfoDataset(dataset.dataset, loss_weights)
+
+         # Step 4: Add constant regularization weights
+         reg_weights = tf.ones(num_nodes, dtype=tf.float32)
+
+         def collate_fn(batch):
+             samples, loss_weights = batch
+             return samples, loss_weights, reg_weights
+
+         # Step 5: Create a tf.data.Dataset with collate_fn
+         train_dataloader = create_tf_dataloader(augmented_dataset,
+                                                 batch_size=dataset.orig_batch_size,
+                                                 collate_fn=collate_fn)
+
+         return train_dataloader

      def _is_gptq_weights_trainable(self,
                                     node: common.BaseNode) -> bool:
@@ -226,9 +278,13 @@ class KerasGPTQTrainer(GPTQTrainer):

          return gptq_model, gptq_user_info

-     def compute_gradients(self, in_y_float: List[tf.Tensor], input_data: List[np.ndarray],
+     def compute_gradients(self,
+                           in_y_float: List[tf.Tensor],
+                           input_data: List[np.ndarray],
                            in_optimizer_with_param: List,
-                           training=True) -> Tuple[tf.Tensor, List[tf.Tensor]]:
+                           training=True,
+                           distill_loss_weights=None,
+                           reg_weights=None) -> Tuple[tf.Tensor, List[tf.Tensor]]:
          """
          Get outputs from both teacher and student networks. Compute the observed error,
          and use it to compute the gradients and applying them to the student weights.
@@ -253,9 +309,9 @@ class KerasGPTQTrainer(GPTQTrainer):
                                                 self.flp_weights_list,
                                                 self.compare_points_mean,
                                                 self.compare_points_std,
-                                                self.weights_for_average_loss)
+                                                distill_loss_weights)

-             reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+             reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, reg_weights)

              loss_value += reg_value

@@ -279,14 +335,19 @@ class KerasGPTQTrainer(GPTQTrainer):
          # Training loop
          # ----------------------------------------------
          if self.has_params_to_train:
-             self.micro_training_loop(self.representative_data_gen_fn,
-                                      compute_gradients,
+             self.micro_training_loop(compute_gradients,
                                       self.optimizer_with_param,
                                       self.gptq_config.n_epochs,
                                       True)

      @tf.function
-     def nano_training_step(self, input_data, in_compute_gradients, in_optimizer_with_param, is_training):
+     def nano_training_step(self,
+                            input_data,
+                            in_compute_gradients,
+                            in_optimizer_with_param,
+                            is_training,
+                            distill_loss_weights,
+                            reg_weights):
          """
          This function run part of the training step, wrapped by a tf.function for acceleration.
          Args:
@@ -303,12 +364,15 @@ class KerasGPTQTrainer(GPTQTrainer):
          # run float model
          y_float = self.float_model(input_data)
          # rung quantized model and calculate loss & gradients
-         loss_value_step, grads = in_compute_gradients(y_float, input_data, in_optimizer_with_param,
-                                                       training=is_training)
+         loss_value_step, grads = in_compute_gradients(y_float,
+                                                       input_data,
+                                                       in_optimizer_with_param,
+                                                       training=is_training,
+                                                       distill_loss_weights=distill_loss_weights,
+                                                       reg_weights=reg_weights)
          return loss_value_step, grads

      def micro_training_loop(self,
-                             data_function: Callable,
                              in_compute_gradients: Callable,
                              in_optimizer_with_param: List[Tuple[tf.keras.optimizers.Optimizer, List[tf.Tensor]]],
                              n_epochs: int,
@@ -316,7 +380,6 @@ class KerasGPTQTrainer(GPTQTrainer):
          """
          This function run a micro training loop on given set of parameters.
          Args:
-             data_function: A callable function that give a batch of samples.
              in_compute_gradients: A callable function that compute the gradients.
              in_optimizer_with_param: A list of optimizer classes to update with the corresponding parameters.
              n_epochs: Number of update iterations of representative dataset.
@@ -327,12 +390,19 @@ class KerasGPTQTrainer(GPTQTrainer):
          """
          with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
              for _ in epochs_pbar:
-                 with tqdm(data_function(), position=1, leave=False) as data_pbar:
+                 with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
                      for data in data_pbar:
-                         input_data = [d * self.input_scale for d in data]

-                         loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
-                                                                          in_optimizer_with_param, is_training)
+                         input_data, distill_loss_weights, reg_weight = data
+
+                         input_data = [d * self.input_scale for d in input_data]
+
+                         loss_value_step, grads = self.nano_training_step(input_data,
+                                                                          in_compute_gradients,
+                                                                          in_optimizer_with_param,
+                                                                          is_training,
+                                                                          distill_loss_weights,
+                                                                          reg_weight)
                          # Run one step of gradient descent by updating
                          # the value of the variables to minimize the loss.
                          for i, (o, p) in enumerate(in_optimizer_with_param):
@@ -16,7 +16,6 @@
  import tensorflow as tf
  from typing import Tuple, List
  from model_compression_toolkit.core.keras.constants import USE_BIAS
- from model_compression_toolkit.core.common.framework_info import FrameworkInfo
  from tensorflow.keras.models import Model
  from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
  from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -26,7 +25,6 @@ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_qu


  def get_gptq_trainable_parameters(fxp_model: Model,
-                                   fw_info: FrameworkInfo,
                                    add_bias: bool = False) -> (
          List[tf.Variable], List[tf.Variable], List[tf.Variable]):
      """
@@ -34,7 +32,6 @@ def get_gptq_trainable_parameters(fxp_model: Model,

      Args:
          fxp_model: Model to get its trainable parameters.
-         fw_info: Framework information needed for keras kernel ops list.
          add_bias: Whether to include biases of the model (if there are) or not.

      Returns:
@@ -60,7 +57,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
              trainable_threshold.extend(quantizer_trainable_threshold)

              if add_bias:
-                 kernel_ops_attrs = fw_info.kernel_ops_attributes_mapping.get(type(layer.layer))
+                 kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
                  use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
                             and layer.layer.get_config().get(USE_BIAS)
                  if use_bias is not None and use_bias and layer.layer.bias is not None:
@@ -19,7 +19,7 @@ from packaging import version

  from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
  from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
-     LR_BIAS_DEFAULT, GPTQ_MOMENTUM
+     LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
  from model_compression_toolkit.logger import Logger
  from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
  from model_compression_toolkit.verify_packages import FOUND_TF
@@ -42,7 +42,7 @@ if FOUND_TF:
      from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
      from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
      from tensorflow.keras.models import Model
-     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
+     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
      from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
      from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
      from model_compression_toolkit import get_target_platform_capabilities
@@ -61,11 +61,12 @@ if FOUND_TF:
      def get_keras_gptq_config(n_epochs: int,
                                optimizer: OptimizerV2 = None,
                                optimizer_rest: OptimizerV2 = None,
-                               loss: Callable = GPTQMultipleTensorsLoss(),
+                               loss: Callable = None,
                                log_function: Callable = None,
                                use_hessian_based_weights: bool = True,
-                               regularization_factor: float = REG_DEFAULT,
+                               regularization_factor: float = None,
                                hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                               use_hessian_sample_attention: bool = False,
                                gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False) -> GradientPTQConfig:
          """
          Create a GradientPTQConfig instance for Keras models.
@@ -79,6 +80,7 @@ if FOUND_TF:
              use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
              regularization_factor (float): A floating point number that defines the regularization factor.
              hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+             use_hessian_sample_attention (bool): whether to use Sample-Layer Attention score for weighted loss.
              gradual_activation_quantization (bool, GradualActivationQuantizationConfig): If False, GradualActivationQuantization is disabled. If True, GradualActivationQuantization is enabled with the default settings. GradualActivationQuantizationConfig object can be passed to use non-default settings.

          returns:
@@ -105,9 +107,25 @@ if FOUND_TF:
          """
          optimizer = optimizer or tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT)
          optimizer_rest = optimizer_rest or tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT)
+         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)

-         bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT,
-                                                  momentum=GPTQ_MOMENTUM)
+         if regularization_factor is None:
+             regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
+
+         loss = loss or GPTQMultipleTensorsLoss()
+         hessian_weights_config = None
+         if use_hessian_sample_attention:
+             if not use_hessian_based_weights: # pragma: no cover
+                 raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
+
+             hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
+                                                              hessians_num_samples=None,
+                                                              hessian_batch_size=hessian_batch_size)
+             loss = loss or sample_layer_attention_loss
+         elif use_hessian_based_weights:
+             hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
+                                                              hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
+                                                              hessian_batch_size=hessian_batch_size)

          if isinstance(gradual_activation_quantization, bool):
              gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
@@ -117,11 +135,6 @@ if FOUND_TF:
              raise TypeError(f'gradual_activation_quantization argument should be bool or '
                              f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')

-         hessian_weights_config = None
-         if use_hessian_based_weights:
-             hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
-                                                              hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
-                                                              hessian_batch_size=hessian_batch_size)
          return GradientPTQConfig(n_epochs=n_epochs,
                                   optimizer=optimizer,
                                   optimizer_rest=optimizer_rest,
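For illustration only (not part of the diff): a sketch of how the extended facade might be called to enable Sample-Layer Attention, assuming the public mct.gptq entry point exposes this function:

    import model_compression_toolkit as mct

    gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5,
                                                 use_hessian_based_weights=True,
                                                 use_hessian_sample_attention=True,
                                                 gradual_activation_quantization=True)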
@@ -40,30 +40,42 @@ class SoftQuantizerRegularization:
          self.count_iter = tf.Variable(0.)


-     def __call__(self, model: Model, entropy_reg: float):
+     def __call__(self, model: Model, entropy_reg: float, layer_weights: tf.Tensor):
          """
          Returns the soft quantizer regularization value for SoftRounding.

          Args:
              model: A model to be quantized with SoftRounding.
              entropy_reg: Entropy value to scale the quantizer regularization.
+             layer_weights: a vector of layers weights.

          Returns: Regularization value.
          """
-         soft_reg_aux: List[tf.Tensor] = []
+         layers = [l for l in model.layers if isinstance(l, KerasTrainableQuantizationWrapper)]
+
+         if layer_weights.shape[0] != len(layers):
+             raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
+                              f'received shape {layer_weights.shape}.') # pragma: no cover
+
          b = self.beta_scheduler(self.count_iter.value())
-         for layer in model.layers:
-             if isinstance(layer, KerasTrainableQuantizationWrapper):
-                 kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                       fw_info=DEFAULT_KERAS_INFO)

-                 st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
-                 soft_reg_aux.append(tf.reduce_sum(1 - tf.pow(tf.math.abs(st - .5) * 2, b)))
+         max_w = tf.reduce_max(layer_weights)
+
+         # Initialize reg to zero
+         reg = tf.constant(0.0, dtype=tf.float32)
+
+         # Compute the regularization term without concatenating
+         for i, layer in enumerate(layers):
+             kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                   fw_info=DEFAULT_KERAS_INFO)
+
+             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()

-         reg = 0
+             soft_loss = tf.reduce_sum(1 - tf.pow(tf.math.abs(st - 0.5) * 2, b))
+             reg += layer_weights[i] * soft_loss

-         for sq in soft_reg_aux:
-             reg += sq
+         # Normalize reg by max_w
+         reg = reg / max_w

          self.count_iter.assign_add(1.0)

@@ -21,9 +21,6 @@ from torch.nn import Module
  from torch.utils.data import DataLoader
  from tqdm import tqdm

- from model_compression_toolkit.gptq.common.gradual_activation_quantization import get_gradual_activation_quantizer_wrapper_factory
- from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
-
  from model_compression_toolkit.core.common import Graph, BaseNode
  from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
  from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -41,7 +38,6 @@ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable
  from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder

  from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
- from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
  from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import PytorchLinearAnnealingScheduler
  from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import SoftQuantizerRegularization as PytorchSoftQuantizerRegularization

@@ -76,13 +72,10 @@ class PytorchGPTQTrainer(GPTQTrainer):
              representative_data_gen: Dataset to use for inputs of the models.
              hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
          """
-         def _get_total_grad_steps():
-             # TODO get it from the dataset
-             return get_total_grad_steps(representative_data_gen) * gptq_config.n_epochs
-
-         # must be set prior to model building in the base class constructor
-         self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-             gptq_config, _get_total_grad_steps, PytorchLinearAnnealingScheduler)
+         self.fw_soft_quantizer_regularization = PytorchSoftQuantizerRegularization
+         self.fw_linear_annealing_scheduler = PytorchLinearAnnealingScheduler
+         self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+         self.fw_get_weights_for_loss_fn = get_weights_for_loss

          super().__init__(graph_float,
                           graph_quant,
@@ -92,40 +85,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                           representative_data_gen_fn=representative_data_gen,
                           hessian_info_service=hessian_info_service)

-         self.loss_list = []
-         self.input_scale = 1
-         if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-             Logger.critical("Input scale mismatch between float and GPTQ networks. "
-                             "Ensure both networks have matching input scales.") # pragma: no cover
-         else:
-             self.input_scale = self.gptq_user_info.input_scale
-
-         trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
-             self.fxp_model,
-             add_bias=self.gptq_config.train_bias)
-
-         self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-         if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                 self.fxp_weights_list)):
-             Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
-                             "and float vs. quantized weights for loss calculation do not match. "
-                             "Verify consistency across these parameters for successful GPTQ training.")
-
-         self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
-                                                                   trainable_bias,
-                                                                   trainable_threshold)
-         hessian_cfg = self.gptq_config.hessian_weights_config
-
-         self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
-         if self.use_sample_layer_attention:
-             # normalization is currently not supported, make sure the config reflects it.
-             if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
-                 raise NotImplementedError()
-             self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
-         else:
-             self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
-
-         self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps, PytorchSoftQuantizerRegularization, PytorchLinearAnnealingScheduler)

      def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
          """