mct-nightly 2.0.0.20240410.422__py3-none-any.whl → 2.0.0.20240411.406__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.0.0.20240410.422
3
+ Version: 2.0.0.20240411.406
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=c33LV9Kt6hpVEoLixt_I5rqhtSzRBPSrdmFEifg-VHU,1573
1
+ model_compression_toolkit/__init__.py,sha256=Py1f8nJnEfhzHK091eeZjxPHNqF_ZXrOa97rXbJWdw0,1573
2
2
  model_compression_toolkit/constants.py,sha256=KW_HUEPmQEYqCvWGyORqkYxpvO7w5LViB5J5D-pm_6o,3648
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -113,7 +113,7 @@ model_compression_toolkit/core/common/quantization/quantization_params_generatio
113
113
  model_compression_toolkit/core/common/quantization/quantization_params_generation/outlier_filter.py,sha256=9gnfJV89jpGwAx8ImJ5E9NjCv3lDtbyulP4OtgWb62M,1772
114
114
  model_compression_toolkit/core/common/quantization/quantization_params_generation/power_of_two_selection.py,sha256=BiwDqt5CeU6CW0Qusy3LwWhFtf2J9BvSuGMsTsG6rSw,8538
115
115
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py,sha256=noEdvGiyyW7acgQ2OFWLedCODibTGYJifC9qo8YIU5U,4558
116
- model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=wAeLTGsbMiUrkTrIdozWN8U5ZESSJzF1p0ZpPywVlw4,4346
116
+ model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py,sha256=H2D9rdChIviL_j0mF6zy8Qeu_ZXKRu-hLqckSAT1MR8,4352
117
117
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_search.py,sha256=7kt0JB8PQE0SW9kg8fCwZ5mBkHNgiRrn0of4ZQYQN2A,41524
118
118
  model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py,sha256=nug6XgsywxYf57XF_Tnt2xwdf0zLLsajiZKEblo4lFc,3882
119
119
  model_compression_toolkit/core/common/quantization/quantization_params_generation/symmetric_selection.py,sha256=QtSAtdAb7sTgtoe9L6DnMFO7rjkOtpzE9kD9xmG7eYM,9743
@@ -320,12 +320,12 @@ model_compression_toolkit/exporter/model_wrapper/fw_agnostic/get_inferable_quant
320
320
  model_compression_toolkit/exporter/model_wrapper/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
321
321
  model_compression_toolkit/exporter/model_wrapper/keras/validate_layer.py,sha256=YffgbVYJG5LKeIsW84Pi7NqzQcvJMeQRnAKQCCmIL6c,3776
322
322
  model_compression_toolkit/exporter/model_wrapper/keras/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
323
- model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=NBDzg2rX5BcVELtExHxS5wi0HFxwpGrEedB4ZPSVMas,5130
323
+ model_compression_toolkit/exporter/model_wrapper/keras/builder/fully_quantized_model_builder.py,sha256=k3UrGAw6vKTmZ-oO1lv0VqK3IpAiet9jlIHyEIoL2u0,5132
324
324
  model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer.py,sha256=uL6tJWC4s2IWUy8GJVwtMWpwZZioRRztfKyPJHo14xI,9442
325
325
  model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
326
326
  model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=uTQcnzvP44CgPO0twsUdiMmTBE_Td6ZdQtz5U0GZuPI,3464
327
327
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
328
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=T3QNZl0JFRAm62Z66quHPx0iNHgXwyfSpoBgbqJBBnY,4915
328
+ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=tbXDDPEeWHRS_5DL8e9tTtG6nJ5UohfkLVjI2EIhQeo,4917
329
329
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=4sN5z-6BXrTE5Dp2FX_jKO9ty5iZ2r4RM7XvXtDVLSI,9348
330
330
  model_compression_toolkit/gptq/__init__.py,sha256=YKg-tMj9D4Yd0xW9VRD5EN1J5JrmlRbNEF2fOSgodqA,1228
331
331
  model_compression_toolkit/gptq/runner.py,sha256=MIg-oBtR1nbHkexySdCJD_XfjRoHSknLotmGBMuD5qM,5924
@@ -338,14 +338,14 @@ model_compression_toolkit/gptq/common/gptq_training.py,sha256=rLA1xlOO-6gWfmc2dL
338
338
  model_compression_toolkit/gptq/keras/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
339
339
  model_compression_toolkit/gptq/keras/gptq_keras_implementation.py,sha256=axBwnCSjq5xk-xGymOwSOqjp39It-CVtGcCTRTf0E_4,1248
340
340
  model_compression_toolkit/gptq/keras/gptq_loss.py,sha256=rbRkF15MYd6nq4G49kcjb_dPTa-XNq9cTkrb93mXawo,6241
341
- model_compression_toolkit/gptq/keras/gptq_training.py,sha256=OhYfH6zxRHrRhCde0lbcV9Hu2oeDD9RXh-O8vOPgLbs,18875
341
+ model_compression_toolkit/gptq/keras/gptq_training.py,sha256=zyVcEQzdnNsrIz32U1pqqoi08hzxRdJ2CumaPFGwbDM,19123
342
342
  model_compression_toolkit/gptq/keras/graph_info.py,sha256=5IvgGlJlgOmQYmldjdCBv7tuzAoY0HazatG5Pedrg0Q,4639
343
343
  model_compression_toolkit/gptq/keras/quantization_facade.py,sha256=zAkzWpWP9_aobWgMo_BlUm7-4fR5dHvoGx0sDqs2rZg,14299
344
344
  model_compression_toolkit/gptq/keras/quantizer/__init__.py,sha256=-DK1CDXvlsnEbki4lukZLpl6Xrbo91_jcqxXlG5Eg6Q,963
345
345
  model_compression_toolkit/gptq/keras/quantizer/base_keras_gptq_quantizer.py,sha256=2YU-x4-Q5f6hkUJf0tw6vcwdNwRMHdefrFjhhyHYsvA,4782
346
346
  model_compression_toolkit/gptq/keras/quantizer/quant_utils.py,sha256=Vt7Qb8i4JsE4sFtcjpfM4FTXTtfV1t6SwfoNH8a_Iaw,5055
347
347
  model_compression_toolkit/gptq/keras/quantizer/quantization_builder.py,sha256=FmK5cPwgLAzrDjHTWf_vbRO5s70S7iwpnjnlqEQTuGE,4408
348
- model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=7kvQQz2zHTRkIzJpsOPe8PWtfsOpcGZ2hjVIxbc-qJo,1906
348
+ model_compression_toolkit/gptq/keras/quantizer/regularization_factory.py,sha256=guf7ygnLsZeWnTDz4yJdE2iTkd1oE0uQAZwKnGV3OAk,1957
349
349
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
350
350
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=qUuMKysUpjWYjNbchFuyb_UFwzV1HL7R3Y7o0Z5rf60,4016
351
351
  model_compression_toolkit/gptq/keras/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=BBSDWLmeywjSM5N6oJkMgcuo7zrXTesB4zLwRGG8QB0,12159
@@ -355,14 +355,14 @@ model_compression_toolkit/gptq/keras/quantizer/ste_rounding/symmetric_ste.py,sha
355
355
  model_compression_toolkit/gptq/pytorch/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
356
356
  model_compression_toolkit/gptq/pytorch/gptq_loss.py,sha256=kDuWw-6zh17wZpYWh4Xa94rpoodf82DksgjQCnL7nBc,2719
357
357
  model_compression_toolkit/gptq/pytorch/gptq_pytorch_implementation.py,sha256=tECPTavxn8EEwgLaP2zvxdJH6Vg9jC0YOIMJ7857Sdc,1268
358
- model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=LN4vOwcMuSSFTSnHDACV9hX_Yd2YIXJRl7WkdODuA0k,16245
358
+ model_compression_toolkit/gptq/pytorch/gptq_training.py,sha256=xkDa62AdIRwv8dEshffALW9Ri66eseEpyUF9taMUKns,16509
359
359
  model_compression_toolkit/gptq/pytorch/graph_info.py,sha256=yXJzDd24zfGs2_vfMovxD1WSh1RxXoPxN4GztOf3P5c,3967
360
360
  model_compression_toolkit/gptq/pytorch/quantization_facade.py,sha256=-4USg-tep6EQSArcTxBowhMeAuExrBTNLOWgHFpsIy4,12699
361
361
  model_compression_toolkit/gptq/pytorch/quantizer/__init__.py,sha256=ZHNHo1yzye44m9_ht4UUZfTpK01RiVR3Tr74-vtnOGI,968
362
362
  model_compression_toolkit/gptq/pytorch/quantizer/base_pytorch_gptq_quantizer.py,sha256=TCA1hAc7raPnrjl06sjFtVM4XUtLtuwAhCGX4U3KGZo,4137
363
363
  model_compression_toolkit/gptq/pytorch/quantizer/quant_utils.py,sha256=OocYYRqvl7rZ37QT0hTzfJnWGiNCPskg7cziTlR7TRk,3893
364
364
  model_compression_toolkit/gptq/pytorch/quantizer/quantization_builder.py,sha256=uT9N_aBj965hvQfKd67fS1B0SXGnOLVcqa3wW4b2iZE,4566
365
- model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=-6fn6U6y2HZXluOfShYLeFKiuiDMVvsF64OTUDCrne4,1908
365
+ model_compression_toolkit/gptq/pytorch/quantizer/regularization_factory.py,sha256=mDWZERLwtDzqWeJUwHMVyGdlS8wPLjJ3NvZiKBP6BNA,1959
366
366
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/__init__.py,sha256=lNJ29DYxaLUPDstRDA1PGI5r9Fulq_hvrZMlhst1Z5g,697
367
367
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py,sha256=oO7WgsAHMnWoXNm_gTKAAe-Nd79mGL_m677ai-ui424,4132
368
368
  model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/symmetric_soft_quantizer.py,sha256=kLVQC1hXzDpP4Jx7AwnA764oGnY5AMEuvUUhAvhz09M,12347
@@ -471,8 +471,8 @@ model_compression_toolkit/trainable_infrastructure/keras/quantize_wrapper.py,sha
471
471
  model_compression_toolkit/trainable_infrastructure/keras/quantizer_utils.py,sha256=MVwXNymmFRB2NXIBx4e2mdJ1RfoHxRPYRgjb1MQP5kY,1797
472
472
  model_compression_toolkit/trainable_infrastructure/pytorch/__init__.py,sha256=huHoBUcKNB6BnY6YaUCcFvdyBtBI172ZoUD8ZYeNc6o,696
473
473
  model_compression_toolkit/trainable_infrastructure/pytorch/base_pytorch_quantizer.py,sha256=7bbzqJN8ZAycVDvZr_5xC-niTAR5df8f03Kooev_pfg,3047
474
- mct_nightly-2.0.0.20240410.422.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
475
- mct_nightly-2.0.0.20240410.422.dist-info/METADATA,sha256=Xx2HTbZkpp4O8bS07IXSnaYSh9ZZTxe61I47ovv9fzE,18795
476
- mct_nightly-2.0.0.20240410.422.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
477
- mct_nightly-2.0.0.20240410.422.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
478
- mct_nightly-2.0.0.20240410.422.dist-info/RECORD,,
474
+ mct_nightly-2.0.0.20240411.406.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
475
+ mct_nightly-2.0.0.20240411.406.dist-info/METADATA,sha256=IbtNTzo6qu2zeJ6yTF4uKQCQlaWuTHvIURKZwP1akx0,18795
476
+ mct_nightly-2.0.0.20240411.406.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
477
+ mct_nightly-2.0.0.20240411.406.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
478
+ mct_nightly-2.0.0.20240411.406.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.0.0.20240410.000422"
30
+ __version__ = "2.0.0.20240411.000406"
@@ -42,14 +42,14 @@ def calculate_quantization_params(graph: Graph,
42
42
 
43
43
  """
44
44
 
45
- Logger.info(f"Running quantization parameters search. "
45
+ Logger.info(f"\nRunning quantization parameters search. "
46
46
  f"This process might take some time, "
47
47
  f"depending on the model size and the selected quantization methods.\n")
48
48
 
49
49
  # Create a list of nodes to compute their thresholds
50
50
  nodes_list: List[BaseNode] = nodes if specific_nodes else graph.nodes()
51
51
 
52
- for n in tqdm(nodes_list, "Calculating quantization params"): # iterate only nodes that we should compute their thresholds
52
+ for n in tqdm(nodes_list, "Calculating quantization parameters"): # iterate only nodes that we should compute their thresholds
53
53
  for candidate_qc in n.candidates_quantization_cfg:
54
54
  for attr in n.get_node_weights_attributes():
55
55
  if n.is_weights_quantization_enabled(attr):
@@ -90,7 +90,7 @@ if FOUND_TF:
90
90
  fw_impl=C.keras.keras_implementation.KerasImplementation())).build_model()
91
91
  exportable_model.trainable = False
92
92
 
93
- Logger.info("Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
93
+ Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
94
94
  "Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
95
95
  "FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md\n"
96
96
  "Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md")
@@ -82,7 +82,7 @@ if FOUND_TORCH:
82
82
  get_activation_quantizer_holder(n,
83
83
  fw_impl=C.pytorch.pytorch_implementation.PytorchImplementation())).build_model()
84
84
 
85
- Logger.info("Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
85
+ Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
86
86
  "Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
87
87
  "FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md\n"
88
88
  "Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md")
@@ -301,21 +301,23 @@ class KerasGPTQTrainer(GPTQTrainer):
301
301
  Returns: None
302
302
 
303
303
  """
304
- for _ in tqdm(range(n_epochs)):
305
- for data in tqdm(data_function()):
306
- input_data = [d * self.input_scale for d in data]
307
-
308
- loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
309
- in_optimizer_with_param, is_training)
310
- # Run one step of gradient descent by updating
311
- # the value of the variables to minimize the loss.
312
- for i, (o, p) in enumerate(in_optimizer_with_param):
313
- o.apply_gradients(zip(grads[i], p))
314
- if self.gptq_config.log_function is not None:
315
- self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
316
- self.compare_points)
317
- self.loss_list.append(loss_value_step.numpy())
318
- Logger.debug(f'last loss value: {self.loss_list[-1]}')
304
+ with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
305
+ for _ in epochs_pbar:
306
+ with tqdm(data_function(), position=1, leave=False) as data_pbar:
307
+ for data in data_pbar:
308
+ input_data = [d * self.input_scale for d in data]
309
+
310
+ loss_value_step, grads = self.nano_training_step(input_data, in_compute_gradients,
311
+ in_optimizer_with_param, is_training)
312
+ # Run one step of gradient descent by updating
313
+ # the value of the variables to minimize the loss.
314
+ for i, (o, p) in enumerate(in_optimizer_with_param):
315
+ o.apply_gradients(zip(grads[i], p))
316
+ if self.gptq_config.log_function is not None:
317
+ self.gptq_config.log_function(loss_value_step, grads[0], in_optimizer_with_param[0][-1],
318
+ self.compare_points)
319
+ self.loss_list.append(loss_value_step.numpy())
320
+ Logger.debug(f'last loss value: {self.loss_list[-1]}')
319
321
 
320
322
  def update_graph(self):
321
323
  """
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from tqdm import tqdm
15
16
  from typing import Callable
16
17
 
17
18
  from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
@@ -35,7 +36,7 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
35
36
  if gptq_config.rounding_type == RoundingType.SoftQuantizer:
36
37
  # dry run on the representative dataset to count number of batches
37
38
  num_batches = 0
38
- for _ in representative_data_gen():
39
+ for _ in tqdm(representative_data_gen(), "GPTQ initialization"):
39
40
  num_batches += 1
40
41
 
41
42
  return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)
@@ -248,22 +248,24 @@ class PytorchGPTQTrainer(GPTQTrainer):
248
248
  data_function: A callable function that give a batch of samples.
249
249
  n_epochs: Number of update iterations of representative dataset.
250
250
  """
251
- for _ in tqdm(range(n_epochs)):
252
- for data in tqdm(data_function()):
253
- input_data = [d * self.input_scale for d in data]
254
- input_tensor = to_torch_tensor(input_data)
255
- y_float = self.float_model(input_tensor) # running float model
256
- loss_value, grads = self.compute_gradients(y_float, input_tensor)
257
- # Run one step of gradient descent by updating the value of the variables to minimize the loss.
258
- for (optimizer, _) in self.optimizer_with_param:
259
- optimizer.step()
260
- optimizer.zero_grad()
261
- if self.gptq_config.log_function is not None:
262
- self.gptq_config.log_function(loss_value.item(),
263
- torch_tensor_to_numpy(grads),
264
- torch_tensor_to_numpy(self.optimizer_with_param[0][-1]))
265
- self.loss_list.append(loss_value.item())
266
- Logger.debug(f'last loss value: {self.loss_list[-1]}')
251
+ with tqdm(range(n_epochs), "Running GPTQ optimization") as epochs_pbar:
252
+ for _ in epochs_pbar:
253
+ with tqdm(data_function(), position=1, leave=False) as data_pbar:
254
+ for data in data_pbar:
255
+ input_data = [d * self.input_scale for d in data]
256
+ input_tensor = to_torch_tensor(input_data)
257
+ y_float = self.float_model(input_tensor) # running float model
258
+ loss_value, grads = self.compute_gradients(y_float, input_tensor)
259
+ # Run one step of gradient descent by updating the value of the variables to minimize the loss.
260
+ for (optimizer, _) in self.optimizer_with_param:
261
+ optimizer.step()
262
+ optimizer.zero_grad()
263
+ if self.gptq_config.log_function is not None:
264
+ self.gptq_config.log_function(loss_value.item(),
265
+ torch_tensor_to_numpy(grads),
266
+ torch_tensor_to_numpy(self.optimizer_with_param[0][-1]))
267
+ self.loss_list.append(loss_value.item())
268
+ Logger.debug(f'last loss value: {self.loss_list[-1]}')
267
269
 
268
270
  def update_graph(self) -> Graph:
269
271
  """
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from tqdm import tqdm
15
16
  from typing import Callable
16
17
 
17
18
  from model_compression_toolkit.gptq import RoundingType, GradientPTQConfig, GradientPTQConfig
@@ -35,7 +36,7 @@ def get_regularization(gptq_config: GradientPTQConfig, representative_data_gen:
35
36
  if gptq_config.rounding_type == RoundingType.SoftQuantizer:
36
37
  # dry run on the representative dataset to count number of batches
37
38
  num_batches = 0
38
- for _ in representative_data_gen():
39
+ for _ in tqdm(representative_data_gen(), "GPTQ initialization"):
39
40
  num_batches += 1
40
41
 
41
42
  return SoftQuantizerRegularization(total_gradient_steps=num_batches * gptq_config.n_epochs)