mct-nightly 2.2.0.20241204.524__py3-none-any.whl → 2.2.0.20241206.524__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/RECORD +17 -17
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/keras/data_util.py +151 -18
- model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py +93 -86
- model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py +17 -0
- model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +1 -2
- model_compression_toolkit/gptq/common/gptq_training.py +58 -0
- model_compression_toolkit/gptq/keras/gptq_loss.py +35 -2
- model_compression_toolkit/gptq/keras/gptq_training.py +137 -67
- model_compression_toolkit/gptq/keras/graph_info.py +1 -4
- model_compression_toolkit/gptq/keras/quantization_facade.py +24 -11
- model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +23 -11
- model_compression_toolkit/gptq/pytorch/gptq_training.py +4 -45
- {mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-model_compression_toolkit/__init__.py,sha256=
+model_compression_toolkit/__init__.py,sha256=KhP8R07jwQig7PMnV7NExSRFSjG_rAbMcGhuL8koQWc,1573
 model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -155,7 +155,7 @@ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256
 model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
 model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
 model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
-model_compression_toolkit/core/keras/data_util.py,sha256=
+model_compression_toolkit/core/keras/data_util.py,sha256=sTEuHUrT8S3CpeAEG0XDlYA0bWZKISGPilObPlO0TA8,6833
 model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
 model_compression_toolkit/core/keras/keras_implementation.py,sha256=HwbIR7x4t-TBNbWHVvVNFk8z-KFt6zM0LWAUXQuNZrk,31753
 model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
@@ -193,9 +193,9 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/softmax_s
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/virtual_activation_weights_composition.py,sha256=wH9ocMLL725-uUPU-zCxdd8NwT5nyd0ZShmI7iuTwF8,1462
 model_compression_toolkit/core/keras/graph_substitutions/substitutions/weights_activation_split.py,sha256=rjIheZW7LbSPv9bzMSmC8wl6UUxaTkd4J2IHinObT-Y,1814
 model_compression_toolkit/core/keras/hessian/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
-model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py,sha256=
-model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=
-model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=
+model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py,sha256=qGEyOzC1_NIcnBmvvjA-GT7o9-PWo0Ko66vcEyLixhw,9180
+model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py,sha256=1o7X9GXSfpEmuB5ee2AaBQ2sN2xzX4-smbrq_0qOGRU,4454
+model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py,sha256=Rl6NNGkHMV0ioEM5bbM4XX7yHDqG6mMp4ifN2VQBDxE,12168
 model_compression_toolkit/core/keras/mixed_precision/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
 model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py,sha256=aW8wR13fK6P6xzbU9XGU60IO1yYzXSo_Hk4qeq486kg,5137
 model_compression_toolkit/core/keras/mixed_precision/configurable_weights_quantizer.py,sha256=Ziydik2j-LvNBXP3TSfUD6rEezPAikzQGib0_IXkmGM,6729
@@ -355,21 +355,21 @@ model_compression_toolkit/gptq/common/gptq_config.py,sha256=QwSEZZlC6OpnpoBQoAFf
 model_compression_toolkit/gptq/common/gptq_constants.py,sha256=8HB0yiX75zZ1IKgQUPWpFCM5sS8HAqslws5XrOhxJQ0,750
 model_compression_toolkit/gptq/common/gptq_framework_implementation.py,sha256=n3mSf4J92kFjekzyGyrJULylI-8Jf5OVWJ5AFoVnEx0,1266
 model_compression_toolkit/gptq/common/gptq_graph.py,sha256=-bL5HhPcKqV8nj4dZPXc5QmQJbFBel6etrioikP0tEo,3039
-model_compression_toolkit/gptq/common/gptq_training.py,sha256=
+model_compression_toolkit/gptq/common/gptq_training.py,sha256=vvrQH1MIW3w90yt9VKBW6jTMXkBrDY82JmCDwT8Kve8,17002
 model_compression_toolkit/gptq/common/gradual_activation_quantization.py,sha256=EgpzMs_aDoB0wQiTagqvcxCTfrgNUuCfdXEXmfNiyb0,3780
 model_compression_toolkit/gptq/common/regularization_factory.py,sha256=hyunpXepVeHyoAFJw6zNLK-3ZHBmiut3lmNisJN_L3E,2514
 model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
-model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=
-model_compression_toolkit/gptq/keras/gptq_training.py,sha256=
-model_compression_toolkit/gptq/keras/graph_info.py,sha256=
-model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=
+model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
+model_compression_toolkit/gptq/keras/gptq_training.py,sha256=0WGiP7Gs4xX3FBs1PNaZ7w3hWRigwQXqYjBrs_-x32o,23241
+model_compression_toolkit/gptq/keras/graph_info.py,sha256=zwoeHX67nJJ5-zYLjzvMXS9TLsy9BsizARbZiDVjVSA,4473
+model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=PO-tNoCoWQpXgefVxqxBfAQ29kGe_DFBgiOQ2DLYato,18005
 model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
 model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
 model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
 model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=rst-u5EB9Xss4ndKqi297WvZ-9RVee2TAUVFelPVKhU,4663
 model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
-model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=
+model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=UTvEL5hN2cEsMwiGBDbpcE0kQr32VFKwlJBWlDg8HNA,3271
 model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=BBSDWLmeywjSM5N6oJkMgcuo7zrXTesB4zLwRGG8QB0,12159
 model_compression_toolkit/gptq/keras/quantizer/soft_rounding/uniform_soft_quantizer.py,sha256=pyhlVpoauHM-zuixHsIGPHFgQoXppL8TlDFCjPE2RuY,10377
 model_compression_toolkit/gptq/keras/quantizer/ste_rounding/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
@@ -377,7 +377,7 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
 model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
 model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phIqeu5-_7_5VBT4DT-FCw7Do,3892
 model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
-model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=
+model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=WtehnyiYXdUXf8-uNpV0mdsalF7YF7eKnL7tcFrzZoE,19549
 model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
 model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=hZFU_ZY-LYcpRZyzzX7NsJievkIYKGdkgBzEoB4rsRQ,16020
 model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
@@ -559,8 +559,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
+mct_nightly-2.2.0.20241206.524.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-2.2.0.20241206.524.dist-info/METADATA,sha256=Q-MSMJXd4He0d0RJ_jhEABCs2FgxB6vZIGjv24boOnw,26446
+mct_nightly-2.2.0.20241206.524.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+mct_nightly-2.2.0.20241206.524.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.2.0.20241206.524.dist-info/RECORD,,
model_compression_toolkit/__init__.py
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.2.0.
+__version__ = "2.2.0.20241206.000524"
model_compression_toolkit/core/keras/data_util.py
@@ -18,6 +18,27 @@ import tensorflow as tf
 
 from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor
 
+import tensorflow as tf
+from typing import Callable, Generator, Sequence, Any
+
+
+def get_tensor_spec(item, ignore_batch_dim=False):
+    """
+    Get the TensorFlow TensorSpec for an item, optionally ignoring the first dimension.
+
+    Args:
+        item: The input item, which could be a tensor, tuple, or list.
+        ignore_batch_dim (bool): Whether to ignore the first dimension of the tensor shape.
+
+    Returns:
+        TensorSpec or a tuple of TensorSpecs.
+    """
+    if isinstance(item, (tuple, list)):
+        return tuple(get_tensor_spec(sub_item, ignore_batch_dim) for sub_item in item)
+
+    shape = item.shape[1:] if ignore_batch_dim else item.shape
+    return tf.TensorSpec(shape=shape, dtype=item.dtype)
+
 
 def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
     """
@@ -29,39 +50,151 @@ def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
     Returns:
         A factory for a flattened data generator.
     """
+
     def gen():
         for inputs_batch in data_gen_fn():
             for sample in zip(*inputs_batch):
-                yield
-    return gen
+                yield tuple([tf.convert_to_tensor(s) for s in sample])
 
+    return gen
 
-# TODO in tf dataset and dataloader are combined within tf.data.Dataset. For advanced use cases such as gptq sla we
-# need to separate dataset from dataloader similarly to torch data_util.
 class TFDatasetFromGenerator:
-        raise TypeError(f'Representative data generator is expected to generate a list of tensors, '
-                        f'got {type(inputs)}')  # pragma: no cover
+    """
+    TensorFlow dataset from a data generator function, batched to a specified size.
+    """
 
+    def __init__(self, data_gen_fn: Callable[[], Generator]):
+        """
+        Args:
+            data_gen_fn: a factory function for data generator that yields lists of tensors.
+        """
+        inputs = next(data_gen_fn())
+        if not isinstance(inputs, list):
+            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
         self.orig_batch_size = inputs[0].shape[0]
-
-        output_signature = tuple([tf.TensorSpec(shape=t.shape[1:], dtype=t.dtype) for t in inputs])
-        dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen), output_signature=output_signature)
-        self.dataset = dataset.batch(batch_size)
         self._size = None
+
+        # TFDatasetFromGenerator flattens the dataset, thus we ignore the batch dimension
+        output_signature = get_tensor_spec(inputs, ignore_batch_dim=True)
+        self.dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)
+
 
     def __iter__(self):
         return iter(self.dataset)
 
     def __len__(self):
        """ Returns the number of batches. """
        if self._size is None:
-            self.
-            return self.
+            self._size = sum(1 for _ in self.dataset)
+        return self._size
+
+
+class FixedTFDataset:
+    """
+    Fixed dataset containing samples from a generator, stored in memory.
+    """
+
+    def __init__(self, data_gen_fn: Callable[[], Generator], n_samples: int = None):
+        """
+        Args:
+            data_gen_fn: data generator function.
+            n_samples: number of samples to store in the dataset. If None, uses all samples in one pass.
+        """
+        inputs = next(data_gen_fn())
+        if not isinstance(inputs, list):
+            raise TypeError(f'Data generator is expected to yield a list of tensors, got {type(inputs)}')
+        self.orig_batch_size = inputs[0].shape[0]
+
+        samples = []
+        for batch in data_gen_fn():
+            samples.extend(zip(*[tf.convert_to_tensor(t) for t in batch]))
+            if n_samples is not None and len(samples) >= n_samples:
+                samples = samples[:n_samples]
+                break
+
+        if n_samples and len(samples) < n_samples:
+            raise ValueError(f'Not enough samples to create a dataset with {n_samples} samples')
+        self.samples = samples
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        return self.samples[index]
+
+
+class FixedSampleInfoDataset:
+    """
+    Dataset for samples with additional info, each element is a tuple of (sample, sample_info).
+    """
+
+    def __init__(self, samples: Sequence, sample_info: Sequence):
+        if not all(len(info) == len(samples) for info in sample_info):
+            raise ValueError('Sample and additional info lengths must match')
+        self.samples = samples
+        self.sample_info = sample_info
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        return self.samples[index], tuple([info[index] for info in self.sample_info])
+
+
+class IterableSampleWithConstInfoDataset:
+    """
+    Augments each sample in an iterable dataset with constant additional information.
+    """
+
+    def __init__(self, samples_dataset: tf.data.Dataset, *info: Any):
+        self.samples_dataset = samples_dataset
+        self.info = info
+
+    def __iter__(self):
+        for sample in self.samples_dataset:
+            yield (sample, *self.info)
+
+
+def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
+    """Create a DataLoader based on samples yielded by data_gen."""
+    ds = TFDatasetFromGenerator(data_gen_fn)
+    return create_tf_dataloader(dataset=ds, batch_size=batch_size)
+
+
+def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
+    """
+    Creates a tf.data.Dataset with specified loading options.
+
+    Args:
+        dataset: The dataset container (e.g., FixedDatasetFromGenerator or FixedSampleInfoDataset).
+        batch_size: Number of samples per batch.
+        shuffle: Whether to shuffle the dataset.
+        collate_fn: A function to apply to each batch (e.g., add extra outputs like regularization weights).
+
+    Returns:
+        tf.data.Dataset: Configured for batching, shuffling, and custom transformations.
+    """
+    def generator():
+        for item in dataset:
+            yield item
+
+    dummy_input_tensors = next(generator())
+
+    output_signature = get_tensor_spec(dummy_input_tensors)
+
+    tf_dataset = tf.data.Dataset.from_generator(
+        generator,
+        output_signature=output_signature
+    )
+
+    if shuffle:
+        tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
+
+    tf_dataset = tf_dataset.batch(batch_size)
 
+    # Apply collate function if provided
+    if collate_fn:
+        tf_dataset = tf_dataset.map(lambda *args: collate_fn(args))
 
-
-    """ Create DataLoader based on samples yielded by data_gen. """
-    return TFDatasetFromGenerator(data_gen_fn, batch_size)
+    return tf_dataset
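For orientation, a minimal sketch of how the new data utilities might be wired together. The representative generator, batch size, and shapes below are illustrative and not taken from the package; only the two imported names come from the diff above.

import numpy as np
import tensorflow as tf
from model_compression_toolkit.core.keras.data_util import TFDatasetFromGenerator, create_tf_dataloader

# Hypothetical representative-data generator: two batches of 8 random images.
def representative_data_gen():
    for _ in range(2):
        yield [np.random.rand(8, 32, 32, 3).astype(np.float32)]

# TFDatasetFromGenerator flattens the generator into single samples;
# batching is then re-applied by create_tf_dataloader.
ds = TFDatasetFromGenerator(representative_data_gen)
loader = create_tf_dataloader(dataset=ds, batch_size=4)

for batch in loader.take(1):
    print(batch[0].shape)  # (4, 32, 32, 3)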
model_compression_toolkit/core/keras/hessian/activation_hessian_scores_calculator_keras.py
@@ -60,96 +60,103 @@ class ActivationHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
         Returns:
             List[np.ndarray]: Scores based on the Hessian-approximation for the requested nodes.
         """
-        [... old lines 63-121 removed; their content is not shown in this diff view ...]
+        model_output_nodes = [ot.node for ot in self.graph.get_outputs()]
+
+        if len([n for n in self.hessian_request.target_nodes if n in model_output_nodes]) > 0:
+            Logger.critical("Trying to compute activation Hessian approximation with respect to the model output. "
+                            "This operation is not supported. "
+                            "Remove the output node from the set of node targets in the Hessian request.")
+
+        grad_model_outputs = self.hessian_request.target_nodes + model_output_nodes
+
+        # Building a model to run Hessian approximation on
+        model, _ = FloatKerasModelBuilder(graph=self.graph, append2output=grad_model_outputs).build_model()
+
+        # Record operations for automatic differentiation
+        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as g:
+            g.watch(self.input_images)
+
+            if len(self.input_images) > 1:
+                outputs = model(self.input_images)
+            else:
+                outputs = model(*self.input_images)
+
+            if len(outputs) != len(grad_model_outputs):  # pragma: no cover
+                Logger.critical(
+                    f"Model for computing activation Hessian approximation expects {len(grad_model_outputs)} "
+                    f"outputs, but got {len(outputs)} output tensors.")
+
+            # Extracting the intermediate activation tensors and the model real output.
+            # Note that we do not allow computing Hessian for output nodes, so there shouldn't be an overlap.
+            num_target_nodes = len(self.hessian_request.target_nodes)
+            # Extract activation tensors of nodes for which we want to compute Hessian
+            target_activation_tensors = outputs[:num_target_nodes]
+            # Extract the model outputs
+            output_tensors = outputs[num_target_nodes:]
+
+            # Unfold and concatenate all outputs to form a single tensor
+            output = self._concat_tensors(output_tensors)
+
+            # List to store the Hessian-approximation scores for each interest point
+            ipts_hessian_approximations = [tf.Variable([0.0], dtype=tf.float32, trainable=True)
+                                           for _ in range(len(target_activation_tensors))]
+
+            # Loop through each interest point activation tensor
+            prev_mean_results = None
+            for j in tqdm(range(self.num_iterations_for_approximation)):  # Approximation iterations
+                # Generate random tensor of 1s and -1s
+                v = self._generate_random_vectors_batch(output.shape)
+                f_v = tf.reduce_sum(v * output)
+                for i, ipt in enumerate(target_activation_tensors):  # Per Interest point activation tensor
+                    interest_point_scores = []  # List to store scores for each interest point
+                    with g.stop_recording():
+                        # Computing the approximation by getting the gradient of (output * v)
+                        hess_v = g.gradient(f_v, ipt)
+
+                        if hess_v is None:
+                            # In case we have an output node, which is an interest point, but it is not
+                            # differentiable, we consider its Hessian to be the initial value 0.
+                            continue  # pragma: no cover
+
+                        if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
                             # Mean over all dims but the batch (CXHXW for conv)
                             hessian_approx = tf.reduce_sum(hess_v ** 2.0,
                                                            axis=tuple(d for d in range(1, len(hess_v.shape))))
-                        [... old lines 125-141 removed; their content is not shown in this diff view ...]
+                        elif self.hessian_request.granularity == HessianScoresGranularity.PER_ELEMENT:
+                            hessian_approx = hess_v ** 2
+                        elif self.hessian_request.granularity == HessianScoresGranularity.PER_OUTPUT_CHANNEL:
+                            axes_to_sum = tuple(d for d in range(1, len(hess_v.shape)-1))
+                            hessian_approx = tf.reduce_sum(hess_v ** 2.0, axis=axes_to_sum)
+
+                        else:  # pragma: no cover
+                            Logger.critical(f"{self.hessian_request.granularity} "
+                                            f"is not supported for Keras activation hessian's approximation scores calculator.")
+
+                        # Free gradients
+                        del hess_v
+
+                        # Update node Hessian approximation mean over random iterations
+                        ipts_hessian_approximations[i] = (j * ipts_hessian_approximations[i] + hessian_approx) / (j + 1)
+
+                # If the change to the mean approximation is insignificant (to all outputs)
+                # we stop the calculation.
+                if j > MIN_HESSIAN_ITER and prev_mean_results is not None:
+                    new_mean_res = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
+                    relative_delta_per_node = (tf.abs(new_mean_res - prev_mean_results) /
+                                               (tf.abs(new_mean_res) + 1e-6))
+                    max_delta = tf.reduce_max(relative_delta_per_node)
+                    if max_delta < HESSIAN_COMP_TOLERANCE:
+                        break
+
+                if self.hessian_request.granularity == HessianScoresGranularity.PER_TENSOR:
                     prev_mean_results = tf.reduce_mean(tf.stack(ipts_hessian_approximations), axis=1)
 
-        [... old lines 144-149 removed; their content is not shown in this diff view ...]
+            # Convert results to list of numpy arrays
+            hessian_results = [h.numpy() for h in ipts_hessian_approximations]
+            # Extend the Hessian tensors shape to align with expected return type
+            # TODO: currently, only per-tensor Hessian is available for activation.
+            # Once implementing per-channel or per-element, this alignment needs to be verified and handled separately.
+            hessian_results = [h[..., np.newaxis] for h in hessian_results]
 
-
+            return hessian_results
 
-        else:  # pragma: no cover
-            Logger.critical(f"{self.hessian_request.granularity} "
-                            f"is not supported for Keras activation hessian's approximation scores calculator.")
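For context (my summary, not text from the package): the rewritten calculator is a Hutchinson-style estimator. With y the concatenated model output and a an interest-point activation, each iteration draws a Rademacher vector v and accumulates the squared gradient of the scalar v·y with respect to a, roughly

\widetilde{H}(a) \;\approx\; \frac{1}{J}\sum_{j=1}^{J}\Big\|\frac{\partial\,(v_j^{\top} y)}{\partial a}\Big\|^{2},\qquad (v_j)_k \in \{-1,+1\}\ \text{i.i.d.},

with the squared entries reduced per tensor, per output channel, or kept per element according to the requested granularity, and an early stop once the running mean changes by less than HESSIAN_COMP_TOLERANCE.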
model_compression_toolkit/core/keras/hessian/hessian_scores_calculator_keras.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+from tensorflow import TensorShape
 
 from model_compression_toolkit.core.common.hessian.hessian_scores_calculator import HessianScoresCalculator
 
@@ -77,3 +78,19 @@ class HessianScoresCalculatorKeras(HessianScoresCalculator):
                 "Unable to concatenate tensors for gradient calculation due to mismatched shapes along the first axis.")  # pragma: no cover
 
         return tf.concat(_r_tensors, axis=1)
+
+    def _generate_random_vectors_batch(self, shape: TensorShape) -> tf.Tensor:
+        """
+        Generate a batch of random vectors for Hutchinson estimation using Rademacher distribution.
+
+        Args:
+            shape: target shape.
+
+        Returns:
+            Random tensor.
+        """
+        v = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)
+        v = tf.where(v == 0, -1, 1)
+        v = tf.cast(v, tf.float32)
+        return v
+
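A standalone sketch of the same Rademacher sampling pattern, useful for checking its statistics outside the calculator (plain TensorFlow; the shape is arbitrary):

import tensorflow as tf

# Draw integers in {0, 1}, map them to {-1, +1}, and cast to float,
# mirroring the new _generate_random_vectors_batch helper.
shape = (4, 16)
v = tf.random.uniform(shape=shape, minval=0, maxval=2, dtype=tf.int32)
v = tf.cast(tf.where(v == 0, -1, 1), tf.float32)

print(tf.reduce_mean(v).numpy())                   # close to 0 for large shapes
print(tf.unique(tf.reshape(v, [-1]))[0].numpy())   # only -1.0 and 1.0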
model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py
@@ -89,8 +89,7 @@ class WeightsHessianScoresCalculatorKeras(HessianScoresCalculatorKeras):
         prev_mean_results = None
         tensors_original_shape = []
         for j in tqdm(range(self.num_iterations_for_approximation)):  # Approximation iterations
-
-            v = tf.random.normal(shape=output.shape)
+            v = self._generate_random_vectors_batch(output.shape)
             f_v = tf.reduce_sum(v * output)
 
             for i, ipt_node in enumerate(self.hessian_request.target_nodes):  # Per Interest point weights tensor
model_compression_toolkit/gptq/common/gptq_training.py
@@ -27,7 +27,11 @@ from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig
 from model_compression_toolkit.gptq.common.gptq_constants import QUANT_PARAM_LEARNING_STR
 from model_compression_toolkit.gptq.common.gptq_framework_implementation import GPTQFrameworkImplemantation
 from model_compression_toolkit.gptq.common.gptq_graph import get_compare_points
+from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
+    get_gradual_activation_quantizer_wrapper_factory
+from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
 
 
 class GPTQTrainer(ABC):
@@ -64,6 +68,14 @@ class GPTQTrainer(ABC):
         self.fw_impl = fw_impl
         self.fw_info = fw_info
         self.representative_data_gen_fn = representative_data_gen_fn
+
+        def _get_total_grad_steps():
+            return get_total_grad_steps(representative_data_gen_fn) * gptq_config.n_epochs
+
+        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(gptq_config,
+                                                                                                      _get_total_grad_steps,
+                                                                                                      self.fw_linear_annealing_scheduler)
+
         # ----------------------------------------------
         # Build two models and create compare nodes
         # ----------------------------------------------
@@ -81,6 +93,52 @@ class GPTQTrainer(ABC):
                             f"an 'HessianInfoService' object must be provided, but received: {hessian_info_service}.")  # pragma: no cover
         self.hessian_service = hessian_info_service
 
+        self.reg_func = get_regularization(self.gptq_config,
+                                           _get_total_grad_steps,
+                                           self.fw_soft_quantizer_regularization,
+                                           self.fw_linear_annealing_scheduler)
+        self.loss_list = []
+        self.input_scale = 1
+        if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
+            Logger.critical("Input scale mismatch between float and GPTQ networks. "
+                            "Ensure both networks have matching input scales.")  # pragma: no cover
+        else:
+            self.input_scale = self.gptq_user_info.input_scale
+
+        trainable_weights, trainable_bias, trainable_threshold = self.fw_get_gptq_trainable_parameters_fn(
+            self.fxp_model,
+            add_bias=self.gptq_config.train_bias)
+        self.flp_weights_list, self.fxp_weights_list = self.fw_get_weights_for_loss_fn(self.fxp_model)
+
+        if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
+                self.fxp_weights_list)):
+            Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
+                            "and the number of float and quantized weights for loss calculation. "
+                            "Ensure all these elements align to proceed with GPTQ training.")
+
+        # In Keras we need to flatten the weights first before attaching the optimizer
+        if len(trainable_weights) > 0 and isinstance(trainable_weights[0], (list, tuple)):
+            trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
+        if len(trainable_bias) > 0 and isinstance(trainable_bias[0], (list, tuple)):
+            trainable_bias = [w for layer_weights in trainable_bias for w in layer_weights]
+
+        self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
+                                                                  trainable_bias,
+                                                                  trainable_threshold)
+        hessian_cfg = self.gptq_config.hessian_weights_config
+
+        self.has_params_to_train = np.sum(
+            [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
+        self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
+
+        if self.use_sample_layer_attention:
+            # normalization is currently not supported, make sure the config reflects it.
+            if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
+                raise NotImplementedError()
+            self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen_fn)
+        else:
+            self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen_fn)
+
     def get_optimizer_with_param(self,
                                  flattened_trainable_weights: List[Any],
                                  flattened_bias_weights: List[Any],
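The effect of this refactor is that the shared setup now lives in the common GPTQTrainer, and each framework trainer only hands it framework-specific hooks before calling the base constructor. A minimal sketch of that pattern; the My* names are placeholders I introduce for illustration, the fw_* attribute names and the GPTQTrainer import come from the diff, and the exact super().__init__ argument list is abridged:

from model_compression_toolkit.gptq.common.gptq_training import GPTQTrainer

class MyFrameworkGPTQTrainer(GPTQTrainer):
    def __init__(self, graph_float, graph_quant, gptq_config, fw_impl, fw_info,
                 representative_data_gen, hessian_info_service=None):
        # Hooks consumed by the common constructor shown in the diff above.
        self.fw_soft_quantizer_regularization = MySoftQuantizerRegularization      # placeholder
        self.fw_linear_annealing_scheduler = MyLinearAnnealingScheduler            # placeholder
        self.fw_get_gptq_trainable_parameters_fn = my_get_gptq_trainable_parameters  # placeholder
        self.fw_get_weights_for_loss_fn = my_get_weights_for_loss                  # placeholder

        super().__init__(graph_float, graph_quant, gptq_config, fw_impl, fw_info,
                         representative_data_gen_fn=representative_data_gen,
                         hessian_info_service=hessian_info_service)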
model_compression_toolkit/gptq/keras/gptq_loss.py
@@ -13,9 +13,8 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Any, Tuple, List
-
 import tensorflow as tf
+from typing import List, Tuple
 
 
 def mse_loss(y: tf.Tensor, x: tf.Tensor, normalized: bool = True) -> tf.Tensor:
@@ -67,6 +66,40 @@ def multiple_tensors_mse_loss(y_list: List[tf.Tensor],
     else:
         return tf.reduce_mean(tf.stack(loss_values_list))
 
+
+def sample_layer_attention_loss(y_list: List[tf.Tensor],
+                                x_list: List[tf.Tensor],
+                                fxp_w_list,
+                                flp_w_list,
+                                act_bn_mean,
+                                act_bn_std,
+                                loss_weights: Tuple[tf.Tensor]) -> tf.Tensor:
+    """
+    Compute Sample Layer Attention loss between two lists of tensors using TensorFlow.
+
+    Args:
+        y_list: First list of tensors.
+        x_list: Second list of tensors.
+        fxp_w_list, flp_w_list, act_bn_mean, act_bn_std: unused (needed to comply with the interface).
+        loss_weights: layer-sample attention scores (a tuple with the same length as the number of layers, where each element is a tf.Tensor vector whose length is the number of samples).
+
+    Returns:
+        Sample Layer Attention loss (a scalar).
+    """
+    loss = 0
+    layers_mean_w = []
+    loss_weights = tf.stack(loss_weights, axis=1)
+
+    for i, (y, x) in enumerate(zip(y_list, x_list)):
+        norm = tf.reduce_sum(tf.square(y - x), axis=1)
+        if len(norm.shape) > 1:
+            norm = tf.reduce_mean(tf.reshape(norm, [norm.shape[0], -1]), axis=1)
+        w = loss_weights[:, i]
+        loss += tf.reduce_mean(w * norm)
+        layers_mean_w.append(tf.reduce_mean(w))
+
+    loss = loss / tf.reduce_max(tf.stack(layers_mean_w))
+    return loss
+
 
 def mse_loss_per_tensor(y: tf.Tensor,
                         x: tf.Tensor,
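To illustrate the expected shapes, a toy call of the new loss. The tensor sizes are made up, and the four unused interface arguments are passed as None:

import tensorflow as tf
from model_compression_toolkit.gptq.keras.gptq_loss import sample_layer_attention_loss

batch, features, n_layers = 4, 8, 2

# Per-layer activations of the quantized (y) and float (x) models.
y_list = [tf.random.normal((batch, features)) for _ in range(n_layers)]
x_list = [tf.random.normal((batch, features)) for _ in range(n_layers)]

# One attention weight per sample, per layer.
loss_weights = tuple(tf.random.uniform((batch,)) for _ in range(n_layers))

loss = sample_layer_attention_loss(y_list, x_list, None, None, None, None, loss_weights)
print(float(loss))  # scalar loss value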
model_compression_toolkit/gptq/keras/gptq_training.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List, Tuple, Union, Generator
 
 import tensorflow as tf
 from keras import Model
@@ -20,11 +20,13 @@ from packaging import version
 from tensorflow.keras.layers import Layer
 from tqdm import tqdm
 
-from model_compression_toolkit.core.common.hessian import HessianInfoService
+from model_compression_toolkit.core.common.hessian import HessianInfoService, HessianScoresGranularity
 # As from Tensorflow 2.6, keras is a separate package and some classes should be imported differently.
 from model_compression_toolkit.core.common.user_info import UserInformation
 from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
-from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader
+from model_compression_toolkit.core.keras.data_util import data_gen_to_dataloader, \
+    FixedSampleInfoDataset, FixedTFDataset, create_tf_dataloader, TFDatasetFromGenerator, \
+    IterableSampleWithConstInfoDataset
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.gptq.common.gradual_activation_quantization import \
     get_gradual_activation_quantizer_wrapper_factory
@@ -83,13 +85,10 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         """
 
-        [... old lines 86-89 removed; their content is not shown in this diff view ...]
-        # which occurs in the base constructor.
-        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-            gptq_config, _get_total_grad_steps, KerasLinearAnnealingScheduler)
+        self.fw_soft_quantizer_regularization = SoftQuantizerRegularization
+        self.fw_linear_annealing_scheduler = KerasLinearAnnealingScheduler
+        self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+        self.fw_get_weights_for_loss_fn = get_weights_for_loss
 
         super().__init__(graph_float,
                          graph_quant,
@@ -99,53 +98,106 @@ class KerasGPTQTrainer(GPTQTrainer):
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
-        self.loss_list = []
-        self.input_scale = 1
-
-        trainable_weights, bias_weights, trainable_threshold = get_gptq_trainable_parameters(
-            self.fxp_model,
-            fw_info,
-            add_bias=gptq_config.train_bias)
-
-        self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-
-        if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                self.fxp_weights_list)):
-            Logger.critical("Mismatch in the number of comparison points, layers with trainable weights, "
-                            "and the number of float and quantized weights for loss calculation. "
-                            "Ensure all these elements align to proceed with GPTQ training.")
-
-        flattened_trainable_weights = [w for layer_weights in trainable_weights for w in layer_weights]
-        flattened_bias_weights = [w for layer_weights in bias_weights for w in layer_weights]
-        trainable_quantization_parameters = trainable_threshold
-        self.optimizer_with_param = self.get_optimizer_with_param(flattened_trainable_weights,
-                                                                  flattened_bias_weights,
-                                                                  trainable_quantization_parameters)
-        self.has_params_to_train = np.sum(
-            [len(optimizer_params_tuple[1]) for optimizer_params_tuple in self.optimizer_with_param]) > 0
-
-        if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-            Logger.critical("Input scale mismatch detected between the float model and the GPTQ model. "
-                            "Confirm that the input scales for both models are correctly configured and aligned.")  # pragma: no cover
-        else:
-            self.input_scale = self.gptq_user_info.input_scale
-
-            SoftQuantizerRegularization,
-            KerasLinearAnnealingScheduler)
-        [... further old lines removed; their content is not shown in this diff view ...]
+    def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+        """
+        Computes Sample-Layer Attention score and builds a train dataloader in TensorFlow.
+
+        Args:
+            data_gen_fn: function for representative dataset generation.
+
+        Returns:
+            TensorFlow dataset yielding three outputs - samples, weights for the distillation loss,
+            and weights for regularization.
+        """
+        # Create a fixed dataset
+        fixed_dataset = FixedTFDataset(data_gen_fn)
+        orig_batch_size = fixed_dataset.orig_batch_size
+
+        # Prepare a separate loader for computing hessians over the whole dataset
+        hess_data_loader = create_tf_dataloader(
+            fixed_dataset,
+            batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size,
+            shuffle=False
+        )
+
+        # Prepare request for Hessian computation
+        request = self._build_hessian_request(
+            granularity=HessianScoresGranularity.PER_OUTPUT_CHANNEL,
+            data_loader=hess_data_loader,
+            n_samples=None
+        )
+        layers_hessians = self.hessian_service.fetch_hessian(request, force_compute=True)
+
+        # Compute SLA score defined as max over elements
+        layers_hessians = {
+            layer: tf.convert_to_tensor(tf.reduce_max(hess, axis=tuple(range(1, len(hess.shape))))) for layer, hess in layers_hessians.items()
+        }
+
+        # Stack hessians for comparison points
+        hessians_tensor = tf.stack([layers_hessians[layer.name] for layer in self.compare_points])
+        assert hessians_tensor.shape[0] == len(self.compare_points)
+        loss_weights = list(hessians_tensor.numpy())  # Convert to a list for compatibility
+
+        # Prepare final dataset with samples and loss weights
+        sla_train_dataset = FixedSampleInfoDataset(fixed_dataset.samples, loss_weights)
+
+        # Calculate regularization weights as mean across samples
+        reg_weights = tf.reduce_mean(hessians_tensor, axis=1)
+
+        # Define a collate function to add regularization weights to each batch
+        def collate_fn(samples_with_loss_weights):
+            return *samples_with_loss_weights, reg_weights
+
+        # Create final dataset using the new dataloader with collate_fn
+        final_dataset = create_tf_dataloader(
+            dataset=sla_train_dataset,
+            batch_size=orig_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn
+        )
+
+        return final_dataset
+
+    def _prepare_train_dataloader_for_non_sla(self,
+                                              data_gen_fn: Callable[[], Generator]) -> tf.data.Dataset:
+        """
+        Prepares a train dataloader for non-SLA tasks.
+
+        Args:
+            data_gen_fn: Factory for representative dataset generator.
+
+        Returns:
+            A `tf.data.Dataset` yielding samples with loss weights and regularization weights.
+        """
+        # Step 1: Create a dataset from the generator
+        dataset = TFDatasetFromGenerator(data_gen_fn)
         num_nodes = len(self.compare_points)
+
+        # Step 2: Compute loss weights
+        if self.gptq_config.hessian_weights_config:
+            hessian_dataset = create_tf_dataloader(dataset=dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
+            hessian_weights = self.compute_hessian_based_weights(hessian_dataset)
+            loss_weights = tf.convert_to_tensor(hessian_weights, dtype=tf.float32)
+        else:
+            loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes
+
+        # Step 3: Create a dataset with samples and loss weights
+        augmented_dataset = IterableSampleWithConstInfoDataset(dataset.dataset, loss_weights)
+
+        # Step 4: Add constant regularization weights
+        reg_weights = tf.ones(num_nodes, dtype=tf.float32)
+
+        def collate_fn(batch):
+            samples, loss_weights = batch
+            return samples, loss_weights, reg_weights
+
+        # Step 5: Create a tf.data.Dataset with collate_fn
+        train_dataloader = create_tf_dataloader(augmented_dataset,
+                                                batch_size=dataset.orig_batch_size,
+                                                collate_fn=collate_fn)
+
+        return train_dataloader
 
     def _is_gptq_weights_trainable(self,
                                    node: common.BaseNode) -> bool:
@@ -226,9 +278,13 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         return gptq_model, gptq_user_info
 
-    def compute_gradients(self,
+    def compute_gradients(self,
+                          in_y_float: List[tf.Tensor],
+                          input_data: List[np.ndarray],
                           in_optimizer_with_param: List,
-                          training=True
+                          training=True,
+                          distill_loss_weights=None,
+                          reg_weights=None) -> Tuple[tf.Tensor, List[tf.Tensor]]:
         """
         Get outputs from both teacher and student networks. Compute the observed error,
         and use it to compute the gradients and applying them to the student weights.
@@ -253,9 +309,9 @@ class KerasGPTQTrainer(GPTQTrainer):
                                            self.flp_weights_list,
                                            self.compare_points_mean,
                                            self.compare_points_std,
-
+                                           distill_loss_weights)
 
-        reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor)
+        reg_value = self.reg_func(self.fxp_model, self.gptq_config.regularization_factor, reg_weights)
 
         loss_value += reg_value
 
@@ -279,14 +335,19 @@ class KerasGPTQTrainer(GPTQTrainer):
         # Training loop
         # ----------------------------------------------
         if self.has_params_to_train:
-            self.micro_training_loop(
-                compute_gradients,
+            self.micro_training_loop(compute_gradients,
                                      self.optimizer_with_param,
                                      self.gptq_config.n_epochs,
                                      True)
 
     @tf.function
-    def nano_training_step(self,
+    def nano_training_step(self,
+                           input_data,
+                           in_compute_gradients,
+                           in_optimizer_with_param,
+                           is_training,
+                           distill_loss_weights,
+                           reg_weights):
         """
         This function run part of the training step, wrapped by a tf.function for acceleration.
         Args:
@@ -303,12 +364,15 @@ class KerasGPTQTrainer(GPTQTrainer):
         # run float model
         y_float = self.float_model(input_data)
         # rung quantized model and calculate loss & gradients
-        loss_value_step, grads = in_compute_gradients(y_float,
-
+        loss_value_step, grads = in_compute_gradients(y_float,
+                                                      input_data,
+                                                      in_optimizer_with_param,
+                                                      training=is_training,
+                                                      distill_loss_weights=distill_loss_weights,
+                                                      reg_weights=reg_weights)
         return loss_value_step, grads
 
     def micro_training_loop(self,
-                            data_function: Callable,
                             in_compute_gradients: Callable,
                             in_optimizer_with_param: List[Tuple[tf.keras.optimizers.Optimizer, List[tf.Tensor]]],
                             n_epochs: int,
@@ -316,7 +380,6 @@ class KerasGPTQTrainer(GPTQTrainer):
         """
         This function run a micro training loop on given set of parameters.
         Args:
-            data_function: A callable function that give a batch of samples.
             in_compute_gradients: A callable function that compute the gradients.
             in_optimizer_with_param: A list of optimizer classes to update with the corresponding parameters.
             n_epochs: Number of update iterations of representative dataset.
@@ -327,12 +390,19 @@ class KerasGPTQTrainer(GPTQTrainer):
         """
         with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
             for _ in epochs_pbar:
-                with tqdm(
+                with tqdm(self.train_dataloader, position=1, leave=False) as data_pbar:
                     for data in data_pbar:
-                        input_data = [d * self.input_scale for d in data]
 
-
-
+                        input_data, distill_loss_weights, reg_weight = data
+
+                        input_data = [d * self.input_scale for d in input_data]
+
+                        loss_value_step, grads = self.nano_training_step(input_data,
+                                                                         in_compute_gradients,
+                                                                         in_optimizer_with_param,
+                                                                         is_training,
+                                                                         distill_loss_weights,
+                                                                         reg_weight)
                         # Run one step of gradient descent by updating
                         # the value of the variables to minimize the loss.
                         for i, (o, p) in enumerate(in_optimizer_with_param):
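After this change each batch delivered by self.train_dataloader is a triple: the input sample(s), the per-layer (per-sample in the SLA path) distillation-loss weights, and the per-layer regularization weights. A runnable toy stand-in for that structure, with made-up shapes and values, purely to show what the training loop unpacks:

import tensorflow as tf

n_layers, batch = 3, 4
samples = tf.data.Dataset.from_tensor_slices(tf.random.normal((8, 32, 32, 3))).batch(batch)
loss_w = tf.ones((n_layers,))   # distillation-loss weights (toy values)
reg_w = tf.ones((n_layers,))    # regularization weights (toy values)
train_dataloader = samples.map(lambda s: (s, loss_w, reg_w))

for input_data, distill_loss_weights, reg_weights in train_dataloader.take(1):
    print(input_data.shape, distill_loss_weights.shape, reg_weights.shape)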
model_compression_toolkit/gptq/keras/graph_info.py
@@ -16,7 +16,6 @@
 import tensorflow as tf
 from typing import Tuple, List
 from model_compression_toolkit.core.keras.constants import USE_BIAS
-from model_compression_toolkit.core.common.framework_info import FrameworkInfo
 from tensorflow.keras.models import Model
 from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
@@ -26,7 +25,6 @@ from model_compression_toolkit.trainable_infrastructure.common.base_trainable_qu
 
 
 def get_gptq_trainable_parameters(fxp_model: Model,
-                                  fw_info: FrameworkInfo,
                                   add_bias: bool = False) -> (
         List[tf.Variable], List[tf.Variable], List[tf.Variable]):
     """
@@ -34,7 +32,6 @@ def get_gptq_trainable_parameters(fxp_model: Model,
 
     Args:
         fxp_model: Model to get its trainable parameters.
-        fw_info: Framework information needed for keras kernel ops list.
         add_bias: Whether to include biases of the model (if there are) or not.
 
     Returns:
@@ -60,7 +57,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
         trainable_threshold.extend(quantizer_trainable_threshold)
 
         if add_bias:
-            kernel_ops_attrs =
+            kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
             use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
                        and layer.layer.get_config().get(USE_BIAS)
             if use_bias is not None and use_bias and layer.layer.bias is not None:
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -19,7 +19,7 @@ from packaging import version
 
 from model_compression_toolkit.core.common.visualization.tensorboard_writer import init_tensorboard_writer
 from model_compression_toolkit.gptq.common.gptq_constants import REG_DEFAULT, LR_DEFAULT, LR_REST_DEFAULT, \
-    LR_BIAS_DEFAULT, GPTQ_MOMENTUM
+    LR_BIAS_DEFAULT, GPTQ_MOMENTUM, REG_DEFAULT_SLA
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.constants import TENSORFLOW, ACT_HESSIAN_DEFAULT_BATCH_SIZE, GPTQ_HESSIAN_NUM_SAMPLES
 from model_compression_toolkit.verify_packages import FOUND_TF
@@ -42,7 +42,7 @@ if FOUND_TF:
     from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
-    from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss
+    from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.exporter.model_wrapper import get_exportable_keras_model
     from model_compression_toolkit import get_target_platform_capabilities
@@ -61,11 +61,12 @@ if FOUND_TF:
     def get_keras_gptq_config(n_epochs: int,
                               optimizer: OptimizerV2 = None,
                               optimizer_rest: OptimizerV2 = None,
-                              loss: Callable =
+                              loss: Callable = None,
                               log_function: Callable = None,
                               use_hessian_based_weights: bool = True,
-                              regularization_factor: float =
+                              regularization_factor: float = None,
                               hessian_batch_size: int = ACT_HESSIAN_DEFAULT_BATCH_SIZE,
+                              use_hessian_sample_attention: bool = False,
                               gradual_activation_quantization: Union[bool, GradualActivationQuantizationConfig] = False) -> GradientPTQConfig:
         """
         Create a GradientPTQConfig instance for Keras models.
@@ -79,6 +80,7 @@ if FOUND_TF:
             use_hessian_based_weights (bool): Whether to use Hessian-based weights for weighted average loss.
             regularization_factor (float): A floating point number that defines the regularization factor.
             hessian_batch_size (int): Batch size for Hessian computation in Hessian-based weights GPTQ.
+            use_hessian_sample_attention (bool): whether to use Sample-Layer Attention score for weighted loss.
             gradual_activation_quantization (bool, GradualActivationQuantizationConfig): If False, GradualActivationQuantization is disabled. If True, GradualActivationQuantization is enabled with the default settings. GradualActivationQuantizationConfig object can be passed to use non-default settings.
 
         returns:
@@ -105,9 +107,25 @@ if FOUND_TF:
         """
         optimizer = optimizer or tf.keras.optimizers.Adam(learning_rate=LR_DEFAULT)
         optimizer_rest = optimizer_rest or tf.keras.optimizers.Adam(learning_rate=LR_REST_DEFAULT)
+        bias_optimizer = tf.keras.optimizers.SGD(learning_rate=LR_BIAS_DEFAULT, momentum=GPTQ_MOMENTUM)
 
-
-
+        if regularization_factor is None:
+            regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
+
+        loss = loss or GPTQMultipleTensorsLoss()
+        hessian_weights_config = None
+        if use_hessian_sample_attention:
+            if not use_hessian_based_weights:  # pragma: no cover
+                raise ValueError('use_hessian_based_weights must be set to True in order to use Sample Layer Attention.')
+
+            hessian_weights_config = GPTQHessianScoresConfig(per_sample=True,
+                                                             hessians_num_samples=None,
+                                                             hessian_batch_size=hessian_batch_size)
+            loss = loss or sample_layer_attention_loss
+        elif use_hessian_based_weights:
+            hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
+                                                             hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
+                                                             hessian_batch_size=hessian_batch_size)
 
         if isinstance(gradual_activation_quantization, bool):
             gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
@@ -117,11 +135,6 @@ if FOUND_TF:
             raise TypeError(f'gradual_activation_quantization argument should be bool or '
                             f'GradualActivationQuantizationConfig, received {type(gradual_activation_quantization)}')
 
-        hessian_weights_config = None
-        if use_hessian_based_weights:
-            hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
-                                                             hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
-                                                             hessian_batch_size=hessian_batch_size)
         return GradientPTQConfig(n_epochs=n_epochs,
                                  optimizer=optimizer,
                                  optimizer_rest=optimizer_rest,
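As a usage illustration, the new flag can be enabled when building the GPTQ configuration. A hedged sketch: the model and data generator are placeholders, and keras_gradient_post_training_quantization is the facade this config is normally passed to (names assumed from the public MCT API, not from this diff):

import model_compression_toolkit as mct

# Hypothetical float Keras model and representative data generator.
# float_model = ...
# def representative_data_gen(): yield [images_batch]

gptq_config = mct.gptq.get_keras_gptq_config(
    n_epochs=5,
    use_hessian_based_weights=True,
    use_hessian_sample_attention=True,   # new argument in this release
    gradual_activation_quantization=True)

# quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
#     float_model, representative_data_gen, gptq_config=gptq_config)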
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -40,30 +40,42 @@ class SoftQuantizerRegularization:
         self.count_iter = tf.Variable(0.)
 
 
-    def __call__(self, model: Model, entropy_reg: float):
+    def __call__(self, model: Model, entropy_reg: float, layer_weights: tf.Tensor):
         """
         Returns the soft quantizer regularization value for SoftRounding.
 
         Args:
             model: A model to be quantized with SoftRounding.
             entropy_reg: Entropy value to scale the quantizer regularization.
+            layer_weights: a vector of layers weights.
 
         Returns: Regularization value.
         """
-
+        layers = [l for l in model.layers if isinstance(l, KerasTrainableQuantizationWrapper)]
+
+        if layer_weights.shape[0] != len(layers):
+            raise ValueError(f'Expected weights.shape[0] to be {len(layers)}, '
+                             f'received shape {layer_weights.shape}.')  # pragma: no cover
+
         b = self.beta_scheduler(self.count_iter.value())
-        for layer in model.layers:
-            if isinstance(layer, KerasTrainableQuantizationWrapper):
-                kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                      fw_info=DEFAULT_KERAS_INFO)
 
-
-
+        max_w = tf.reduce_max(layer_weights)
+
+        # Initialize reg to zero
+        reg = tf.constant(0.0, dtype=tf.float32)
+
+        # Compute the regularization term without concatenating
+        for i, layer in enumerate(layers):
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)
+
+            st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
 
-
+            soft_loss = tf.reduce_sum(1 - tf.pow(tf.math.abs(st - 0.5) * 2, b))
+            reg += layer_weights[i] * soft_loss
 
-
-
+        # Normalize reg by max_w
+        reg = reg / max_w
 
         self.count_iter.assign_add(1.0)
 
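In formula form (my notation, inferred from the code above), with soft targets st_l of layer l and annealed exponent β(t) from the beta scheduler, the weighted regularizer is roughly

R(t) \;=\; \frac{1}{\max_{l} w_{l}} \sum_{l} w_{l} \sum_{k} \Big(1 - \big|\,2\,(st_{l,k} - 0.5)\,\big|^{\beta(t)}\Big),

where w_l are the per-layer weights passed in as layer_weights: constant ones in the non-SLA path, and Hessian-derived per-layer means in the Sample-Layer Attention path.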
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -21,9 +21,6 @@ from torch.nn import Module
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-from model_compression_toolkit.gptq.common.gradual_activation_quantization import get_gradual_activation_quantizer_wrapper_factory
-from model_compression_toolkit.gptq.common.regularization_factory import get_regularization
-
 from model_compression_toolkit.core.common import Graph, BaseNode
 from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
 from model_compression_toolkit.core.common.framework_info import FrameworkInfo
@@ -41,7 +38,6 @@ from model_compression_toolkit.gptq.pytorch.graph_info import get_gptq_trainable
 from model_compression_toolkit.gptq.pytorch.quantizer.quantization_builder import quantization_builder
 
 from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
-from model_compression_toolkit.trainable_infrastructure.common.util import get_total_grad_steps
 from model_compression_toolkit.trainable_infrastructure.pytorch.annealing_schedulers import PytorchLinearAnnealingScheduler
 from model_compression_toolkit.gptq.pytorch.quantizer.soft_rounding.soft_quantizer_reg import SoftQuantizerRegularization as PytorchSoftQuantizerRegularization
 
@@ -76,13 +72,10 @@ class PytorchGPTQTrainer(GPTQTrainer):
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
-        [... old lines 79-82 removed; their content is not shown in this diff view ...]
-        # must be set prior to model building in the base class constructor
-        self.gradual_act_quantizer_wrapper_factory = get_gradual_activation_quantizer_wrapper_factory(
-            gptq_config, _get_total_grad_steps, PytorchLinearAnnealingScheduler)
+        self.fw_soft_quantizer_regularization = PytorchSoftQuantizerRegularization
+        self.fw_linear_annealing_scheduler = PytorchLinearAnnealingScheduler
+        self.fw_get_gptq_trainable_parameters_fn = get_gptq_trainable_parameters
+        self.fw_get_weights_for_loss_fn = get_weights_for_loss
 
         super().__init__(graph_float,
                          graph_quant,
@@ -92,40 +85,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
-        self.loss_list = []
-        self.input_scale = 1
-        if self.float_user_info.input_scale != self.gptq_user_info.input_scale:
-            Logger.critical("Input scale mismatch between float and GPTQ networks. "
-                            "Ensure both networks have matching input scales.")  # pragma: no cover
-        else:
-            self.input_scale = self.gptq_user_info.input_scale
-
-        trainable_weights, trainable_bias, trainable_threshold = get_gptq_trainable_parameters(
-            self.fxp_model,
-            add_bias=self.gptq_config.train_bias)
-
-        self.flp_weights_list, self.fxp_weights_list = get_weights_for_loss(self.fxp_model)
-        if not (len(self.compare_points) == len(trainable_weights) == len(self.flp_weights_list) == len(
-                self.fxp_weights_list)):
-            Logger.critical("GPTQ: Number of comparison points, layers with trainable weights, "
-                            "and float vs. quantized weights for loss calculation do not match. "
-                            "Verify consistency across these parameters for successful GPTQ training.")
-
-        self.optimizer_with_param = self.get_optimizer_with_param(trainable_weights,
-                                                                  trainable_bias,
-                                                                  trainable_threshold)
-        hessian_cfg = self.gptq_config.hessian_weights_config
-
-        self.use_sample_layer_attention = hessian_cfg and hessian_cfg.per_sample
-        if self.use_sample_layer_attention:
-            # normalization is currently not supported, make sure the config reflects it.
-            if hessian_cfg.norm_scores or hessian_cfg.log_norm or hessian_cfg.scale_log_norm:
-                raise NotImplementedError()
-            self.train_dataloader = self._prepare_train_dataloader_sla(representative_data_gen)
-        else:
-            self.train_dataloader = self._prepare_train_dataloader_for_non_sla(representative_data_gen)
-
-        self.reg_func = get_regularization(self.gptq_config, _get_total_grad_steps, PytorchSoftQuantizerRegularization, PytorchLinearAnnealingScheduler)
 
     def _prepare_train_dataloader_sla(self, data_gen_fn: Callable[[], Generator]) -> DataLoader:
         """
{mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.2.0.20241204.524.dist-info → mct_nightly-2.2.0.20241206.524.dist-info}/top_level.txt
RENAMED
File without changes