mct-nightly 2.2.0.20250108.523__py3-none-any.whl → 2.2.0.20250109.528__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20250108.523
3
+ Version: 2.2.0.20250109.528
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -23,6 +23,12 @@ Requires-Dist: scipy
23
23
  Requires-Dist: protobuf
24
24
  Requires-Dist: mct-quantizers==1.5.2
25
25
  Requires-Dist: pydantic<2.0
26
+ Dynamic: classifier
27
+ Dynamic: description
28
+ Dynamic: description-content-type
29
+ Dynamic: requires-dist
30
+ Dynamic: requires-python
31
+ Dynamic: summary
26
32
 
27
33
  <div align="center" markdown="1">
28
34
  <p>
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=T9vJvHgnVfRm0QlXM_rI0gN4vrj4U14twllXd5p0Irs,1573
1
+ model_compression_toolkit/__init__.py,sha256=5vDy_YtqV8rtgpIP3dBPtIxugUkeVPFJU05OH1jkhdw,1573
2
2
  model_compression_toolkit/constants.py,sha256=i_R6uXBfO1ph_X6DNJych2x59SUojfJbn7dNjs_mZnc,3846
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -155,7 +155,7 @@ model_compression_toolkit/core/common/visualization/tensorboard_writer.py,sha256
155
155
  model_compression_toolkit/core/keras/__init__.py,sha256=mjbqLD-KcG3eNeCYpu1GBS7VclGVOQ63x2p6mAAuba4,698
156
156
  model_compression_toolkit/core/keras/constants.py,sha256=dh4elQWt6Q6NYRht5k5RiiOcnLAq1v0MMBCJqMJzzFk,3225
157
157
  model_compression_toolkit/core/keras/custom_layer_validation.py,sha256=f-b14wuiIgitBe7d0MmofYhDCTO3IhwJgwrh-Hq_t_U,1192
158
- model_compression_toolkit/core/keras/data_util.py,sha256=HQj3-GP5oT5JHpYt80mtKhZjTCvKYs6c3Ll0txEgKHQ,6892
158
+ model_compression_toolkit/core/keras/data_util.py,sha256=-fqhXTzlA3RybWp0M5phPkzVbSJ2vPLrjFcCazWMYHk,7300
159
159
  model_compression_toolkit/core/keras/default_framework_info.py,sha256=PYcER89eEXjKtR0T7-2Y4f7cckqoD5OQbpHePoRkMec,5030
160
160
  model_compression_toolkit/core/keras/keras_implementation.py,sha256=HwbIR7x4t-TBNbWHVvVNFk8z-KFt6zM0LWAUXQuNZrk,31753
161
161
  model_compression_toolkit/core/keras/keras_model_validation.py,sha256=1wNV2clFdC9BzIELRLSO2uKf0xqjLqlkTJudwtCeaJk,1722
@@ -362,9 +362,9 @@ model_compression_toolkit/gptq/common/regularization_factory.py,sha256=hyunpXepV
362
362
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
363
363
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
364
364
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=2hzWzsbuVd5XcL85NM57YeOyHxRY0qMArKn8NvQ1UWw,7643
365
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=0WGiP7Gs4xX3FBs1PNaZ7w3hWRigwQXqYjBrs_-x32o,23241
365
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=km9tcuugOkRvprGXQZrsq_GPtA3-7Du_-rnbR_Gyups,23228
366
366
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=zwoeHX67nJJ5-zYLjzvMXS9TLsy9BsizARbZiDVjVSA,4473
367
- model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=meRKqpzZe2Irf21L_rN_mkr5dqPTJHzfSFBeqv4Csp4,18536
367
+ model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=jUAjkIszziedftaQBSmjEL6tYEYpHhlFpSgw2X9OTf4,18672
368
368
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
369
369
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=Rbl9urzkmACvVxICSEyJ02qFOBxWK0UQWtysFJzBVZw,4899
370
370
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
@@ -380,7 +380,7 @@ model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=_07Zx_43bnNokwR5S8phI
380
380
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
381
381
  model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=WtehnyiYXdUXf8-uNpV0mdsalF7YF7eKnL7tcFrzZoE,19549
382
382
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=4mVM-VvnBaA64ACVdOe6wTGHdMSa2UTLIUe7nACLcdo,4008
383
- model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=kMSq9mrpcgMBRgrEKfMBHaJG6HhGRYnuiDzF4ofckwo,16581
383
+ model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=HSFpx6JgjxGhU-0jA0z85sOOgSjCq6gzDOSkmuksZVE,16713
384
384
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
385
385
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=fKg-PNOhGBiL-4eySS9Fyw0GkA76Pq8jT_HbJuJ8iZU,4143
386
386
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
@@ -525,8 +525,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
525
525
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=3jNiV5Z4BVw9cEWuLKNOlLuLdr0EMuKg6eYnSiAq3LU,3952
526
526
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
527
527
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
528
- mct_nightly-2.2.0.20250108.523.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
529
- mct_nightly-2.2.0.20250108.523.dist-info/METADATA,sha256=riYLks2VpIMjq7W0UIbOGVmX68cYfNGXjl04SFFNSnE,26461
530
- mct_nightly-2.2.0.20250108.523.dist-info/WHEEL,sha256=A3WOREP4zgxI0fKrHUG8DC8013e3dK3n7a6HDbcEIwE,91
531
- mct_nightly-2.2.0.20250108.523.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
532
- mct_nightly-2.2.0.20250108.523.dist-info/RECORD,,
528
+ mct_nightly-2.2.0.20250109.528.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
529
+ mct_nightly-2.2.0.20250109.528.dist-info/METADATA,sha256=AYbks8Hsbv8a3bBOMjAhG2oNxrXPfYTn8TEnOUj3KjI,26601
530
+ mct_nightly-2.2.0.20250109.528.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
531
+ mct_nightly-2.2.0.20250109.528.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
532
+ mct_nightly-2.2.0.20250109.528.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.7.0)
2
+ Generator: setuptools (75.8.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20250108.000523"
30
+ __version__ = "2.2.0.20250109.000528"
@@ -12,11 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- from typing import Generator, Callable
16
-
17
- import tensorflow as tf
18
-
19
- from model_compression_toolkit.core.keras.tf_tensor_numpy import to_tf_tensor
20
15
 
21
16
  import tensorflow as tf
22
17
  from typing import Callable, Generator, Sequence, Any
@@ -58,7 +53,6 @@ def flat_gen_fn(data_gen_fn: Callable[[], Generator]):
58
53
 
59
54
  return gen
60
55
 
61
-
62
56
  class TFDatasetFromGenerator:
63
57
  """
64
58
  TensorFlow dataset from a data generator function, batched to a specified size.
@@ -77,15 +71,15 @@ class TFDatasetFromGenerator:
77
71
 
78
72
  # TFDatasetFromGenerator flattens the dataset, thus we ignore the batch dimension
79
73
  output_signature = get_tensor_spec(inputs, ignore_batch_dim=True)
80
- self.dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)
74
+ self.tf_dataset = tf.data.Dataset.from_generator(flat_gen_fn(data_gen_fn), output_signature=output_signature)
81
75
 
82
76
  def __iter__(self):
83
- return iter(self.dataset)
77
+ return iter(self.tf_dataset)
84
78
 
85
79
  def __len__(self):
86
80
  """ Returns the number of batches. """
87
81
  if self._size is None:
88
- self._size = sum(1 for _ in self.dataset)
82
+ self._size = sum(1 for _ in self.tf_dataset)
89
83
  return self._size
90
84
 
91
85
 
@@ -116,6 +110,12 @@ class FixedTFDataset:
116
110
  raise ValueError(f'Not enough samples to create a dataset with {n_samples} samples')
117
111
  self.samples = samples
118
112
 
113
+ # Use from_generator to keep tuples intact
114
+ self.tf_dataset = tf.data.Dataset.from_generator(
115
+ lambda: iter(self.samples),
116
+ output_signature=tuple(tf.TensorSpec(shape=sample.shape, dtype=sample.dtype) for sample in self.samples[0])
117
+ )
118
+
119
119
  def __len__(self):
120
120
  return len(self.samples)
121
121
 
@@ -134,6 +134,12 @@ class FixedSampleInfoDataset:
134
134
  self.samples = samples
135
135
  self.sample_info = sample_info
136
136
 
137
+ # Create a TensorFlow dataset that holds (sample, sample_info) tuples
138
+ self.tf_dataset = tf.data.Dataset.from_tensor_slices((
139
+ tf.convert_to_tensor(self.samples),
140
+ tuple(tf.convert_to_tensor(info) for info in self.sample_info)
141
+ ))
142
+
137
143
  def __len__(self):
138
144
  return len(self.samples)
139
145
 
@@ -150,18 +156,23 @@ class IterableSampleWithConstInfoDataset:
150
156
  self.samples_dataset = samples_dataset
151
157
  self.info = info
152
158
 
159
+ # Map to ensure the output is always (sample, info) as a tuple
160
+ self.tf_dataset = self.samples_dataset.map(
161
+ lambda *x: ((x,) if not isinstance(x, tuple) else x, *self.info)
162
+ )
163
+
153
164
  def __iter__(self):
154
165
  for sample in self.samples_dataset:
155
- yield (sample, *self.info)
166
+ yield ((sample,) if not isinstance(sample, tuple) else sample, *self.info)
156
167
 
157
168
 
158
169
  def data_gen_to_dataloader(data_gen_fn: Callable[[], Generator], batch_size: int):
159
170
  """Create a DataLoader based on samples yielded by data_gen."""
160
171
  ds = TFDatasetFromGenerator(data_gen_fn)
161
- return create_tf_dataloader(dataset=ds, batch_size=batch_size)
172
+ return create_tf_dataloader(mct_dataset=ds, batch_size=batch_size)
162
173
 
163
174
 
164
- def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
175
+ def create_tf_dataloader(mct_dataset, batch_size, shuffle=False, collate_fn=None):
165
176
  """
166
177
  Creates a tf.data.Dataset with specified loading options.
167
178
 
@@ -174,26 +185,15 @@ def create_tf_dataloader(dataset, batch_size, shuffle=False, collate_fn=None):
174
185
  Returns:
175
186
  tf.data.Dataset: Configured for batching, shuffling, and custom transformations.
176
187
  """
177
- def generator():
178
- for item in dataset:
179
- yield item
180
-
181
- dummy_input_tensors = next(generator())
182
-
183
- output_signature = get_tensor_spec(dummy_input_tensors)
184
-
185
- tf_dataset = tf.data.Dataset.from_generator(
186
- generator,
187
- output_signature=output_signature
188
- )
188
+ dataset = mct_dataset.tf_dataset
189
189
 
190
190
  if shuffle:
191
- tf_dataset = tf_dataset.shuffle(buffer_size=len(dataset))
191
+ dataset = dataset.shuffle(buffer_size=len(dataset))
192
192
 
193
- tf_dataset = tf_dataset.batch(batch_size)
193
+ dataset = dataset.batch(batch_size)
194
194
 
195
195
  # Apply collate function if provided
196
196
  if collate_fn:
197
- tf_dataset = tf_dataset.map(lambda *args: collate_fn(args))
197
+ dataset = dataset.map(lambda *args: collate_fn(args))
198
198
 
199
- return tf_dataset
199
+ return dataset
@@ -151,7 +151,7 @@ class KerasGPTQTrainer(GPTQTrainer):
151
151
 
152
152
  # Create final dataset using the new dataloader with collate_fn
153
153
  final_dataset = create_tf_dataloader(
154
- dataset=sla_train_dataset,
154
+ sla_train_dataset,
155
155
  batch_size=orig_batch_size,
156
156
  shuffle=True,
157
157
  collate_fn=collate_fn
@@ -176,14 +176,14 @@ class KerasGPTQTrainer(GPTQTrainer):
176
176
 
177
177
  # Step 2: Compute loss weights
178
178
  if self.gptq_config.hessian_weights_config:
179
- hessian_dataset = create_tf_dataloader(dataset=dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
179
+ hessian_dataset = create_tf_dataloader(dataset, batch_size=self.gptq_config.hessian_weights_config.hessian_batch_size)
180
180
  hessian_weights = self.compute_hessian_based_weights(hessian_dataset)
181
181
  loss_weights = tf.convert_to_tensor(hessian_weights, dtype=tf.float32)
182
182
  else:
183
183
  loss_weights = tf.ones(num_nodes, dtype=tf.float32) / num_nodes
184
184
 
185
185
  # Step 3: Create a dataset with samples and loss weights
186
- augmented_dataset = IterableSampleWithConstInfoDataset(dataset.dataset, loss_weights)
186
+ augmented_dataset = IterableSampleWithConstInfoDataset(dataset.tf_dataset, loss_weights)
187
187
 
188
188
  # Step 4: Add constant regularization weights
189
189
  reg_weights = tf.ones(num_nodes, dtype=tf.float32)
@@ -115,7 +115,6 @@ if FOUND_TF:
115
115
  if regularization_factor is None:
116
116
  regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
117
117
 
118
- loss = loss or GPTQMultipleTensorsLoss()
119
118
  hessian_weights_config = None
120
119
  if use_hessian_sample_attention:
121
120
  if not use_hessian_based_weights: # pragma: no cover
@@ -129,7 +128,10 @@ if FOUND_TF:
129
128
  hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
130
129
  hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
131
130
  hessian_batch_size=hessian_batch_size)
132
-
131
+
132
+ # If a loss was not passed (and was not initialized due to use_hessian_sample_attention), use the default loss
133
+ loss = loss or GPTQMultipleTensorsLoss()
134
+
133
135
  if isinstance(gradual_activation_quantization, bool):
134
136
  gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None
135
137
  elif isinstance(gradual_activation_quantization, GradualActivationQuantizationConfig):
@@ -104,7 +104,6 @@ if FOUND_TORCH:
104
104
  if regularization_factor is None:
105
105
  regularization_factor = REG_DEFAULT_SLA if use_hessian_sample_attention else REG_DEFAULT
106
106
 
107
- loss = loss or multiple_tensors_mse_loss
108
107
  hessian_weights_config = None
109
108
  if use_hessian_sample_attention:
110
109
  if not use_hessian_based_weights: # pragma: no cover
@@ -118,6 +117,9 @@ if FOUND_TORCH:
118
117
  hessian_weights_config = GPTQHessianScoresConfig(per_sample=False,
119
118
  hessians_num_samples=GPTQ_HESSIAN_NUM_SAMPLES,
120
119
  hessian_batch_size=hessian_batch_size)
120
+
121
+ # If a loss was not passed (and was not initialized due to use_hessian_sample_attention), use the default loss
122
+ loss = loss or multiple_tensors_mse_loss
121
123
 
122
124
  if isinstance(gradual_activation_quantization, bool):
123
125
  gradual_quant_config = GradualActivationQuantizationConfig() if gradual_activation_quantization else None