mct-nightly 2.4.0.20250925.543__py3-none-any.whl → 2.4.2.20250926.532__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/METADATA +6 -3
  2. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/RECORD +165 -159
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +5 -2
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +4 -0
  6. model_compression_toolkit/core/common/collectors/base_collector.py +1 -4
  7. model_compression_toolkit/core/common/collectors/mean_collector.py +4 -7
  8. model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.py +4 -7
  9. model_compression_toolkit/core/common/framework_implementation.py +22 -10
  10. model_compression_toolkit/core/common/framework_info.py +83 -93
  11. model_compression_toolkit/core/common/fusion/graph_fuser.py +9 -12
  12. model_compression_toolkit/core/common/graph/base_graph.py +72 -45
  13. model_compression_toolkit/core/common/graph/base_node.py +141 -121
  14. model_compression_toolkit/core/common/graph/functional_node.py +2 -19
  15. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +21 -17
  16. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +18 -8
  17. model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py +9 -14
  18. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +21 -12
  19. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +3 -2
  20. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +5 -2
  21. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +6 -3
  22. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +10 -5
  23. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +5 -2
  24. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +9 -4
  25. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +7 -2
  26. model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +5 -7
  27. model_compression_toolkit/core/common/model_collector.py +18 -22
  28. model_compression_toolkit/core/common/model_validation.py +44 -0
  29. model_compression_toolkit/core/common/network_editors/__init__.py +1 -8
  30. model_compression_toolkit/core/common/network_editors/actions.py +130 -14
  31. model_compression_toolkit/core/common/network_editors/edit_network.py +4 -1
  32. model_compression_toolkit/core/common/pruning/channels_grouping.py +5 -1
  33. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +6 -0
  34. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +15 -5
  35. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +7 -3
  36. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +4 -2
  37. model_compression_toolkit/core/common/pruning/memory_calculator.py +13 -5
  38. model_compression_toolkit/core/common/pruning/prune_graph.py +4 -1
  39. model_compression_toolkit/core/common/pruning/pruner.py +6 -1
  40. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +13 -5
  41. model_compression_toolkit/core/common/pruning/pruning_section.py +18 -9
  42. model_compression_toolkit/core/common/quantization/bit_width_config.py +10 -10
  43. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +55 -116
  44. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +14 -20
  45. model_compression_toolkit/core/common/quantization/node_quantization_config.py +228 -43
  46. model_compression_toolkit/core/common/quantization/quantization_config.py +1 -0
  47. model_compression_toolkit/core/common/quantization/quantization_fn_selection.py +1 -21
  48. model_compression_toolkit/core/common/quantization/quantization_params_fn_selection.py +78 -0
  49. model_compression_toolkit/core/common/quantization/quantization_params_generation/__init__.py +5 -8
  50. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +76 -91
  51. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +66 -36
  52. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_weights_computation.py +32 -61
  53. model_compression_toolkit/core/common/quantization/quantize_node.py +8 -8
  54. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +412 -93
  55. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +7 -3
  56. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +19 -6
  57. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +19 -11
  58. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +15 -15
  59. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +20 -4
  60. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +9 -4
  61. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +12 -8
  62. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +6 -3
  63. model_compression_toolkit/core/common/substitutions/scale_equalization.py +21 -5
  64. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +55 -43
  65. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +3 -1
  66. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  67. model_compression_toolkit/core/common/visualization/nn_visualizer.py +8 -3
  68. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +12 -8
  69. model_compression_toolkit/core/graph_prep_runner.py +35 -22
  70. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +4 -0
  71. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +5 -0
  72. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +15 -8
  73. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +6 -5
  74. model_compression_toolkit/core/keras/default_framework_info.py +91 -131
  75. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +7 -2
  76. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +1 -0
  77. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +18 -29
  78. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +16 -8
  79. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +5 -4
  80. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +13 -3
  81. model_compression_toolkit/core/keras/keras_implementation.py +37 -17
  82. model_compression_toolkit/core/keras/keras_model_validation.py +38 -0
  83. model_compression_toolkit/core/keras/keras_node_prior_info.py +13 -4
  84. model_compression_toolkit/core/keras/mixed_precision/configurable_activation_quantizer.py +1 -2
  85. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +34 -19
  86. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  87. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +5 -3
  88. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +12 -3
  89. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +16 -9
  90. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +5 -1
  91. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +3 -2
  92. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +6 -5
  93. model_compression_toolkit/core/pytorch/default_framework_info.py +79 -93
  94. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +4 -3
  95. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/relu_bound_to_power_of_2.py +5 -5
  96. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +8 -4
  97. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +4 -3
  98. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +12 -3
  99. model_compression_toolkit/core/pytorch/mixed_precision/configurable_activation_quantizer.py +1 -2
  100. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +41 -24
  101. model_compression_toolkit/core/pytorch/pytorch_implementation.py +33 -13
  102. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +5 -1
  103. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  104. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +5 -3
  105. model_compression_toolkit/core/quantization_prep_runner.py +11 -6
  106. model_compression_toolkit/core/runner.py +15 -5
  107. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  108. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  109. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +0 -2
  110. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +1 -0
  111. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +9 -13
  112. model_compression_toolkit/gptq/common/gptq_graph.py +11 -5
  113. model_compression_toolkit/gptq/common/gptq_training.py +8 -1
  114. model_compression_toolkit/gptq/keras/gptq_training.py +9 -3
  115. model_compression_toolkit/gptq/keras/graph_info.py +6 -4
  116. model_compression_toolkit/gptq/keras/quantization_facade.py +10 -4
  117. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  118. model_compression_toolkit/gptq/pytorch/gptq_training.py +9 -3
  119. model_compression_toolkit/gptq/pytorch/graph_info.py +3 -1
  120. model_compression_toolkit/gptq/pytorch/quantization_facade.py +7 -5
  121. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +3 -1
  122. model_compression_toolkit/gptq/runner.py +7 -1
  123. model_compression_toolkit/pruning/keras/pruning_facade.py +12 -7
  124. model_compression_toolkit/pruning/pytorch/pruning_facade.py +8 -4
  125. model_compression_toolkit/ptq/keras/quantization_facade.py +13 -5
  126. model_compression_toolkit/ptq/pytorch/quantization_facade.py +8 -4
  127. model_compression_toolkit/ptq/runner.py +4 -1
  128. model_compression_toolkit/qat/common/qat_config.py +6 -2
  129. model_compression_toolkit/qat/keras/quantization_facade.py +13 -7
  130. model_compression_toolkit/qat/pytorch/quantization_facade.py +11 -7
  131. model_compression_toolkit/target_platform_capabilities/constants.py +1 -1
  132. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -3
  133. model_compression_toolkit/trainable_infrastructure/common/get_quantizer_config.py +2 -0
  134. model_compression_toolkit/trainable_infrastructure/common/trainable_quantizer_config.py +6 -0
  135. model_compression_toolkit/trainable_infrastructure/keras/config_serialization.py +4 -2
  136. model_compression_toolkit/xquant/__init__.py +1 -0
  137. model_compression_toolkit/xquant/common/constants.py +1 -0
  138. model_compression_toolkit/xquant/common/model_folding_utils.py +6 -1
  139. model_compression_toolkit/xquant/common/tensorboard_utils.py +4 -1
  140. model_compression_toolkit/xquant/common/xquant_config.py +27 -1
  141. model_compression_toolkit/xquant/{common → keras}/core_report_generator.py +2 -2
  142. model_compression_toolkit/xquant/keras/facade_xquant_report.py +1 -1
  143. model_compression_toolkit/xquant/{common → keras}/framework_report_utils.py +23 -2
  144. model_compression_toolkit/xquant/keras/keras_report_utils.py +10 -5
  145. model_compression_toolkit/xquant/keras/similarity_calculator.py +199 -0
  146. model_compression_toolkit/xquant/keras/tensorboard_utils.py +3 -0
  147. model_compression_toolkit/xquant/pytorch/core_detect_degrade_layer.py +77 -0
  148. model_compression_toolkit/xquant/pytorch/core_judge_troubleshoot.py +66 -0
  149. model_compression_toolkit/xquant/pytorch/core_report_generator.py +177 -0
  150. model_compression_toolkit/xquant/pytorch/detect_degrade_utils.py +78 -0
  151. model_compression_toolkit/xquant/pytorch/facade_xquant_report.py +41 -1
  152. model_compression_toolkit/xquant/pytorch/framework_report_utils.py +98 -0
  153. model_compression_toolkit/xquant/pytorch/judge_troubleshoot_utils.py +562 -0
  154. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +10 -7
  155. model_compression_toolkit/xquant/{common → pytorch}/similarity_calculator.py +6 -1
  156. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +3 -0
  157. model_compression_toolkit/core/keras/quantization/activation_quantization_fn_factory.py +0 -47
  158. model_compression_toolkit/core/pytorch/quantization/activation_quantization_fn_factory.py +0 -45
  159. model_compression_toolkit/quantization_preparation/__init__.py +0 -14
  160. model_compression_toolkit/quantization_preparation/load_fqc.py +0 -223
  161. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/WHEEL +0 -0
  162. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/licenses/LICENSE.md +0 -0
  163. {mct_nightly-2.4.0.20250925.543.dist-info → mct_nightly-2.4.2.20250926.532.dist-info}/top_level.txt +0 -0
  164. /model_compression_toolkit/core/keras/{quantization → quantizer}/__init__.py +0 -0
  165. /model_compression_toolkit/core/keras/{quantization → quantizer}/fake_quant_builder.py +0 -0
  166. /model_compression_toolkit/core/keras/{quantization → quantizer}/lut_fake_quant.py +0 -0
  167. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/__init__.py +0 -0
  168. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/fake_quant_builder.py +0 -0
  169. /model_compression_toolkit/core/pytorch/{quantization → quantizer}/lut_fake_quant.py +0 -0
model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py
@@ -56,7 +56,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         super(ReduceLROnPlateau, self).__init__()

         if factor >= 1.0:
-            Logger.critical('Factor should be < 1.0.') # pragma: no cover
+            Logger.critical('Factor should be < 1.0.')  # pragma: no cover
         self.factor = factor

         self.optimizer = optimizer
@@ -101,7 +101,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         else:
             self.num_bad_epochs += 1

-        if self.in_cooldown: # pragma: no cover
+        if self.in_cooldown:
             self.cooldown_counter -= 1
             self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown

@@ -122,7 +122,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         new_lr = max(old_lr * self.factor, self.min_lr)
         if old_lr - new_lr > self.eps:
             tf.keras.backend.set_value(self.optimizer.learning_rate, new_lr)
-            if self.verbose: # pragma: no cover
+            if self.verbose:
                 print(f'Epoch {epoch:05d}: reducing learning rate to {new_lr:.4e}.')

     @property
@@ -152,13 +152,13 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         if self.mode == 'min' and self.threshold_mode == 'rel':
             rel_epsilon = 1. - self.threshold
             return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs': # pragma: no cover
+        elif self.mode == 'min' and self.threshold_mode == 'abs':
             return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel': # pragma: no cover
+        elif self.mode == 'max' and self.threshold_mode == 'rel':
             rel_epsilon = self.threshold + 1.
             return a > best * rel_epsilon
         else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold # pragma: no cover
+            return a > best + self.threshold

     def _init_is_better(self, mode: str, threshold: float, threshold_mode: str) -> None:
         """
@@ -186,7 +186,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         self.threshold = threshold
         self.threshold_mode = threshold_mode

-    def get_config(self) -> Dict: # pragma: no cover
+    def get_config(self) -> Dict:
         """
         Return the configuration of the scheduler as a dictionary.

@@ -207,7 +207,7 @@ class ReduceLROnPlateau(tf.keras.callbacks.Callback):
         base_config = super(ReduceLROnPlateau, self).get_config()
         return {**base_config, **config}

-    def set_config(self, config: Dict) -> None: # pragma: no cover
+    def set_config(self, config: Dict) -> None:
         """
         Set the configuration of the scheduler from a dictionary.

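The hunks above mostly drop `# pragma: no cover` markers, bringing the mode/threshold branches of the scheduler's improvement test under coverage. For reference, a minimal self-contained sketch (not the package's code) of the four comparison modes mirrored from the hunk at `@@ -152,13 +152,13 @@`:

```python
# Minimal sketch of the ReduceLROnPlateau improvement test: 'mode' selects the
# optimization direction, 'threshold_mode' selects relative vs. absolute slack.
def is_better(a: float, best: float, mode: str, threshold: float, threshold_mode: str) -> bool:
    if mode == 'min' and threshold_mode == 'rel':
        return a < best * (1. - threshold)   # beat best by a relative margin
    elif mode == 'min' and threshold_mode == 'abs':
        return a < best - threshold          # beat best by an absolute margin
    elif mode == 'max' and threshold_mode == 'rel':
        return a > best * (threshold + 1.)
    else:  # mode == 'max' and threshold_mode == 'abs'
        return a > best + threshold

assert is_better(0.89, 1.0, 'min', 0.1, 'rel')       # 0.89 < 1.0 * 0.9
assert not is_better(0.95, 1.0, 'min', 0.1, 'rel')   # 0.95 >= 0.9
```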
model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py
@@ -60,10 +60,10 @@ class ReduceLROnPlateauWithReset:
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
             Logger.critical('{} is not an Optimizer'.format(
-                type(optimizer).__name__)) # pragma: no cover
+                type(optimizer).__name__))  # pragma: no cover
         self.optimizer = optimizer

-        if isinstance(min_lr, (list, tuple)): # pragma: no cover
+        if isinstance(min_lr, (list, tuple)):
             if len(min_lr) != len(optimizer.param_groups):
                 Logger.critical("expected {} min_lrs, got {}".format(
                     len(optimizer.param_groups), len(min_lr)))  # pragma: no cover
@@ -117,7 +117,7 @@ class ReduceLROnPlateauWithReset:
             self.num_bad_epochs += 1

         # Handle cooldown period
-        if self.in_cooldown: # pragma: no cover
+        if self.in_cooldown:
             self.cooldown_counter -= 1
             self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown

@@ -142,7 +142,7 @@ class ReduceLROnPlateauWithReset:
         new_lr = max(old_lr * self.factor, self.min_lrs[i])
         if old_lr - new_lr > self.eps:
             param_group['lr'] = new_lr
-            if self.verbose: # pragma: no cover
+            if self.verbose:
                 epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                 print('Epoch {}: reducing learning rate'
                       ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
@@ -168,19 +168,19 @@ class ReduceLROnPlateauWithReset:
         Returns:
             bool: True if the new value is better, False otherwise.
         """
-        if best is None: # pragma: no cover
+        if best is None:
             return True

         if self.mode == 'min' and self.threshold_mode == 'rel':
             rel_epsilon = 1. - self.threshold
             return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs': # pragma: no cover
+        elif self.mode == 'min' and self.threshold_mode == 'abs':
             return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel': # pragma: no cover
+        elif self.mode == 'max' and self.threshold_mode == 'rel':
             rel_epsilon = self.threshold + 1.
             return a > best * rel_epsilon
         else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold # pragma: no cover
+            return a > best + self.threshold

     def _init_is_better(self) -> None:
         """
@@ -197,9 +197,9 @@ class ReduceLROnPlateauWithReset:
         if self.mode == 'min':
             self.mode_worse = float('inf')
         else:  # mode == 'max':
-            self.mode_worse = float('-inf') # pragma: no cover
+            self.mode_worse = float('-inf')

-    def state_dict(self) -> Dict[str, Any]: # pragma: no cover
+    def state_dict(self) -> Dict[str, Any]:
         """
         Return the state of the scheduler as a dictionary.

@@ -208,7 +208,7 @@ class ReduceLROnPlateauWithReset:
         """
         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: # pragma: no cover
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         """
         Load the scheduler state.

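`state_dict` and `load_state_dict` also lose their coverage pragmas here. Their serialization idiom is visible in the hunk at `@@ -208,7 +208,7 @@`: store every attribute except the externally owned optimizer reference. A runnable sketch of that idiom with a stand-in class (not the package's):

```python
# Stand-in sketch of the ReduceLROnPlateauWithReset checkpointing idiom:
# serialize everything except the optimizer, restore in-place.
from typing import Any, Dict

class TinyScheduler:
    def __init__(self, optimizer: Any, factor: float = 0.5):
        self.optimizer = optimizer
        self.factor = factor
        self.num_bad_epochs = 0

    def state_dict(self) -> Dict[str, Any]:
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)  # optimizer reference left untouched

sched = TinyScheduler(optimizer=object())
state = sched.state_dict()
assert 'optimizer' not in state and state['factor'] == 0.5
sched.load_state_dict(state)
```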
model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py
@@ -21,7 +21,6 @@ from model_compression_toolkit.logger import Logger

 if FOUND_TF:
     import keras
-    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
     from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import \
         FakelyQuantKerasExporter
@@ -37,7 +36,6 @@ if FOUND_TF:
         KerasExportSerializationFormat.TFLITE: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.INT8]
     }

-    @set_keras_info
     def keras_export_model(model: keras.models.Model,
                            save_model_path: str,
                            is_layer_exportable_fn: Callable = is_keras_layer_exportable,
model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py
@@ -19,6 +19,7 @@ import torch.nn

 from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper

+from model_compression_toolkit.core.pytorch.reader.node_holders import DummyPlaceHolder
 from model_compression_toolkit.verify_packages import FOUND_ONNX
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.core.pytorch.utils import to_torch_tensor
model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py
@@ -27,7 +27,6 @@ DEFAULT_ONNX_OPSET_VERSION = 15

 if FOUND_TORCH:
     import torch.nn
-    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
     from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
     from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
@@ -42,14 +41,13 @@ if FOUND_TORCH:
         PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.MCTQ]
     }

-    @set_pytorch_info
     def pytorch_export_model(model: torch.nn.Module,
                              save_model_path: str,
                              repr_dataset: Callable,
                              is_layer_exportable_fn: Callable = is_pytorch_layer_exportable,
                              serialization_format: PytorchExportSerializationFormat = PytorchExportSerializationFormat.ONNX,
                              quantization_format: QuantizationFormat = QuantizationFormat.MCTQ,
-                             onnx_opset_version=DEFAULT_ONNX_OPSET_VERSION,
+                             onnx_opset_version: int = DEFAULT_ONNX_OPSET_VERSION,
                              output_names: Optional[List[str]] = None) -> None:
         """
         Export a PyTorch quantized model to a torchscript or onnx model.
@@ -60,16 +58,14 @@ if FOUND_TORCH:
         (where the model will be saved to ONNX model).

         Args:
-            model: Model to export.
-            save_model_path: Path to save the model.
-            repr_dataset: Representative dataset for tracing the pytorch model (mandatory for exporting it).
-            is_layer_exportable_fn: Callable to check whether a layer can be exported or not.
-            serialization_format: Format to export the model according to (by default
-                PytorchExportSerializationFormat.ONNX).
-            quantization_format: Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
-            onnx_opset_version: ONNX opset version to use for exported ONNX model.
-            output_names (Optional[List[str]]): Optional list of output node names for export compatibility.
-                This argument is relevant only when using PytorchExportSerializationFormat.ONNX.
+            model (Module): Model to export.
+            save_model_path (str): Path to save the model.
+            repr_dataset (Callable): Representative dataset for tracing the pytorch model (mandatory for exporting it).
+            is_layer_exportable_fn (Callable): Callable to check whether a layer can be exported or not.
+            serialization_format (PytorchExportSerializationFormat): Format to export the model according to (by default PytorchExportSerializationFormat.ONNX).
+            quantization_format (QuantizationFormat): Format of how quantizers are exported (fakely-quant, int8, MCTQ quantizers).
+            onnx_opset_version (int): ONNX opset version to use for exported ONNX model.
+            output_names (Optional[List[str]]): Optional list of output node names for export compatibility. This argument is relevant only when using PytorchExportSerializationFormat.ONNX.

         """
         # Ensure 'metadata' is available directly on the model, if present in submodules
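For context, a hedged usage sketch of the exporter entry point whose signature is touched above (`onnx_opset_version` is now annotated as `int`). The PTQ call that produces the exportable model, the toy layer sizes, and the output path are illustrative assumptions, not part of this diff:

```python
import torch
import model_compression_toolkit as mct

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def repr_dataset():
    # Representative dataset: yields a list of input tensors per batch.
    yield [torch.randn(1, 3, 32, 32)]

# Quantize with PTQ first (assumed flow), then export through the facade above.
quantized_model, _ = mct.ptq.pytorch_post_training_quantization(model, repr_dataset)
mct.exporter.pytorch_export_model(model=quantized_model,
                                  save_model_path='qmodel.onnx',
                                  repr_dataset=repr_dataset,
                                  onnx_opset_version=15)  # DEFAULT_ONNX_OPSET_VERSION
```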
model_compression_toolkit/gptq/common/gptq_graph.py
@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Tuple, List

+from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.logger import Logger
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode

@@ -40,7 +40,8 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
     compare_points_name = []
     for n in input_graph.get_topo_sorted_nodes():
         # only nodes with kernel attribute are currently trained with GPTQ and are used as compare points
-        if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and not n.reuse:
+        kernel_attr = input_graph.fw_info.get_kernel_op_attributes(n.type)[0]
+        if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr) and not n.reuse:
             compare_points.append(n)
             compare_points_name.append(n.name)
             compare_points_std.append(n.prior_info.std_output)
@@ -48,15 +49,20 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
     return compare_points, compare_points_name, compare_points_mean, compare_points_std


-def get_kernel_attribute_name_for_gptq(layer_type: type) -> str:
+def get_kernel_attribute_name_for_gptq(layer_type: type, fw_info: FrameworkInfo) -> str:
     """
     Returns a layer's kernel attribute name for GPTQ training purposes.

     Args:
         layer_type: A type of model's layer.
+        fw_info: A FrameworkInfo object.

     Returns: The name of the kernel attribute.

     """
-
-    return get_fw_info().get_kernel_op_attribute(layer_type)
+    kernel_attribute = fw_info.get_kernel_op_attributes(layer_type)
+    if len(kernel_attribute) != 1:
+        Logger.critical(  # pragma: no cover
+            f"In GPTQ training, only the kernel weights attribute should be trained. "
+            f"However, the number of kernel attributes is {len(kernel_attribute)}.")
+    return kernel_attribute[0]
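The reworked `get_kernel_attribute_name_for_gptq` now takes an explicit `FrameworkInfo` and rejects layer types that do not expose exactly one kernel attribute. A hedged call sketch using the Keras defaults that later hunks in this diff import (`DEFAULT_KERAS_INFO`); the expected `'kernel'` result for `Conv2D` is an assumption based on MCT's Keras framework info:

```python
import tensorflow as tf
from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq

# Resolve the trainable kernel attribute for a layer type; fw_info is now explicit
# instead of being fetched from a process-wide global.
kernel_attr = get_kernel_attribute_name_for_gptq(layer_type=tf.keras.layers.Conv2D,
                                                 fw_info=DEFAULT_KERAS_INFO)
print(kernel_attr)  # expected: 'kernel'
```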
model_compression_toolkit/gptq/common/gptq_training.py
@@ -44,6 +44,7 @@ class GPTQTrainer(ABC):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: GPTQFrameworkImplemantation,
+                 fw_info: FrameworkInfo,
                  representative_data_gen_fn: Callable[[], Generator],
                  hessian_info_service: HessianInfoService = None):
         """
@@ -57,6 +58,7 @@ class GPTQTrainer(ABC):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfig with parameters about the tuning process.
             fw_impl: Framework implementation
+            fw_info: Framework information
             representative_data_gen_fn: factory for representative data generator.
             hessian_info_service: HessianInfoService for fetching and computing Hessian-approximation information.
         """
@@ -64,6 +66,7 @@ class GPTQTrainer(ABC):
         self.graph_quant = copy.deepcopy(graph_quant)
         self.gptq_config = gptq_config
         self.fw_impl = fw_impl
+        self.fw_info = fw_info
         self.representative_data_gen_fn = representative_data_gen_fn

         def _get_total_grad_steps():
@@ -80,7 +83,8 @@ class GPTQTrainer(ABC):

         self.float_model, self.float_user_info = fw_impl.model_builder(self.graph_float,
                                                                        mode=ModelBuilderMode.FLOAT,
-                                                                       append2output=self.compare_points)
+                                                                       append2output=self.compare_points,
+                                                                       fw_info=self.fw_info)

         self.fxp_model, self.gptq_user_info = self.build_gptq_model()
         if self.gptq_config.hessian_weights_config:
@@ -284,6 +288,7 @@ def gptq_training(graph_float: Graph,
                   gptq_config: GradientPTQConfig,
                   representative_data_gen: Callable,
                   fw_impl: GPTQFrameworkImplemantation,
+                  fw_info: FrameworkInfo,
                   hessian_info_service: HessianInfoService = None) -> Graph:
     """
     GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
@@ -293,6 +298,7 @@ def gptq_training(graph_float: Graph,
         gptq_config: GradientPTQConfig with parameters about the tuning process.
         representative_data_gen: Dataset to use for inputs of the models.
         fw_impl: Framework implementation
+        fw_info: Framework information
         hessian_info_service: HessianInfoService to fetch information based on the Hessian approximation.

     Returns:
@@ -306,6 +312,7 @@ def gptq_training(graph_float: Graph,
                               graph_quant,
                               gptq_config,
                               fw_impl,
+                              fw_info,
                               representative_data_gen,
                               hessian_info_service=hessian_info_service)

model_compression_toolkit/gptq/keras/gptq_training.py
@@ -65,6 +65,7 @@ class KerasGPTQTrainer(GPTQTrainer):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
+                 fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
                  hessian_info_service: HessianInfoService = None):
         """
@@ -78,6 +79,7 @@ class KerasGPTQTrainer(GPTQTrainer):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfig with parameters about the tuning process.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+            fw_info: Framework information.
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianScoresService for fetching and computing Hessian's approximation scores.

@@ -92,6 +94,7 @@ class KerasGPTQTrainer(GPTQTrainer):
                          graph_quant,
                          gptq_config,
                          fw_impl,
+                         fw_info,
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)

@@ -207,7 +210,8 @@ class KerasGPTQTrainer(GPTQTrainer):
         Returns:
             A boolean whether the layer is to be wrapped with a QuantizeWrapper
         """
-        return node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)
+        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
+        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)

     def gptq_wrapper(self,
                      n: common.BaseNode,
@@ -226,7 +230,7 @@ class KerasGPTQTrainer(GPTQTrainer):
         # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
         weights_quantizers, _ = quantization_builder(n,
                                                      self.gptq_config,  # TODO: split quantizers building into two functions: for weights and activations
-                                                     n.kernel_attr)
+                                                     self.fw_info.get_kernel_op_attributes(n.type)[0])
         if len(weights_quantizers) > 0:
             return KerasTrainableQuantizationWrapper(layer,
                                                      weights_quantizers=weights_quantizers)
@@ -267,6 +271,7 @@ class KerasGPTQTrainer(GPTQTrainer):

         gptq_model, gptq_user_info = KerasModelBuilder(graph=self.graph_quant,
                                                        append2output=self.compare_points,
+                                                       fw_info=self.fw_info,
                                                        return_float_outputs=True,
                                                        wrapper=self.gptq_wrapper,
                                                        get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
@@ -426,7 +431,8 @@ class KerasGPTQTrainer(GPTQTrainer):
             Logger.critical(f"Unable to update the GPTQ graph because the layer named '{layer.layer.name}' could not be found. "
                             f"Verify that the layer names in the GPTQ model match those in the graph.")
         node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                              fw_info=self.fw_info)
         # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
         # To enable GPTQ for other attributes, this code needs to be modified.
         weights, weight_quant_config, activation_quant_config = \
model_compression_toolkit/gptq/keras/graph_info.py
@@ -16,8 +16,8 @@
 import tensorflow as tf
 from typing import Tuple, List
 from model_compression_toolkit.core.keras.constants import USE_BIAS
-from model_compression_toolkit.core.common.framework_info import get_fw_info
 from tensorflow.keras.models import Model
+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
@@ -44,7 +44,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,

     for layer in fxp_model.layers:
         if isinstance(layer, KerasTrainableQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)

             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
@@ -56,8 +57,9 @@ def get_gptq_trainable_parameters(fxp_model: Model,
             trainable_threshold.extend(quantizer_trainable_threshold)

     if add_bias:
-        kernel_ops_attr = get_fw_info().get_kernel_op_attribute(type(layer.layer))
-        use_bias = kernel_ops_attr is not None and layer.layer.get_config().get(USE_BIAS)
+        kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
+        use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
+                   and layer.layer.get_config().get(USE_BIAS)
         if use_bias is not None and use_bias and layer.layer.bias is not None:
             bias_weights.append([layer.layer.bias])
model_compression_toolkit/gptq/keras/quantization_facade.py
@@ -41,8 +41,9 @@ from model_compression_toolkit.metadata import create_model_metadata

 if FOUND_TF:
     import tensorflow as tf
-    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
+    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
     from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
+    from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
     from model_compression_toolkit.gptq.keras.gptq_loss import GPTQMultipleTensorsLoss, sample_layer_attention_loss
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
@@ -151,7 +152,6 @@ if FOUND_TF:
                                                 gradual_activation_quantization_config=gradual_quant_config)


-    @set_keras_info
     def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
                                                   gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,
@@ -234,13 +234,16 @@ if FOUND_TF:
         if core_config.debug_config.bypass:
             return in_model, None

+        KerasModelValidation(model=in_model,
+                             fw_info=DEFAULT_KERAS_INFO).validate()
+
         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
                 Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                                 "Ensure usage of the correct API for keras_post_training_quantization "
                                 "or provide a valid mixed-precision configuration.")  # pragma: no cover

-        tb_w = init_tensorboard_writer()
+        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)

         fw_impl = GPTQKerasImplemantation()
@@ -254,6 +257,7 @@ if FOUND_TF:
         tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
                                                                                    representative_data_gen=representative_data_gen,
                                                                                    core_config=core_config,
+                                                                                   fw_info=DEFAULT_KERAS_INFO,
                                                                                    fw_impl=fw_impl,
                                                                                    fqc=framework_platform_capabilities,
                                                                                    target_resource_utilization=target_resource_utilization,
@@ -267,6 +271,7 @@ if FOUND_TF:
                                          gptq_config,
                                          representative_data_gen,
                                          gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
+                                         DEFAULT_KERAS_INFO,
                                          fw_impl,
                                          tb_w,
                                          hessian_info_service=hessian_info_service)
@@ -278,7 +283,8 @@ if FOUND_TF:
                                    tb_w,
                                    float_graph,
                                    tg_gptq,
-                                   fw_impl)
+                                   fw_impl,
+                                   DEFAULT_KERAS_INFO)

         exportable_model, user_info = get_exportable_keras_model(tg_gptq)
         if framework_platform_capabilities.tpc.add_metadata:
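End to end, the facade above is still called the same way from user code; only the internal wiring (the added `KerasModelValidation` check and explicit `DEFAULT_KERAS_INFO` threading) changes. A hedged usage sketch with illustrative layer sizes and random data:

```python
import numpy as np
import tensorflow as tf
import model_compression_toolkit as mct

model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3)),
                             tf.keras.layers.ReLU()])

def representative_data_gen():
    # One illustrative calibration batch; real flows yield many batches.
    yield [np.random.randn(1, 32, 32, 3).astype(np.float32)]

gptq_config = mct.gptq.get_keras_gptq_config(n_epochs=5)
quantized_model, quantization_info = mct.gptq.keras_gradient_post_training_quantization(
    model, representative_data_gen, gptq_config=gptq_config)
```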
model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -17,6 +17,7 @@ from typing import List, Callable
 import tensorflow as tf
 from keras import Model

+from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper

@@ -65,7 +66,8 @@ class SoftQuantizerRegularization:

         # Compute the regularization term without concatenating
         for i, layer in enumerate(layers):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_KERAS_INFO)

             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
model_compression_toolkit/gptq/pytorch/gptq_training.py
@@ -54,6 +54,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
+                 fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
                  hessian_info_service: HessianInfoService = None):
         """
@@ -67,6 +68,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfigV2 with parameters about the tuning process.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+            fw_info: Framework information
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
@@ -79,6 +81,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          graph_quant,
                          gptq_config,
                          fw_impl,
+                         fw_info,
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)

@@ -164,7 +167,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
             A boolean whether the layer is to be wrapped with a Quantization Wrapper.
         """

-        return node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)
+        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
+        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)

     def gptq_wrapper(self,
                      n: BaseNode,
@@ -183,7 +187,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
         # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
         weights_quantizers, _ = quantization_builder(n,
                                                      self.gptq_config,
-                                                     n.kernel_attr)
+                                                     self.fw_info.get_kernel_op_attributes(n.type)[0])

         if len(weights_quantizers) > 0:
             return PytorchQuantizationWrapper(layer,
@@ -220,6 +224,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
         """
         gptq_model, gptq_user_info = PyTorchModelBuilder(graph=self.graph_quant,
                                                          append2output=self.compare_points,
+                                                         fw_info=self.fw_info,
                                                          wrapper=self.gptq_wrapper,
                                                          return_float_outputs=True,
                                                          get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
@@ -335,7 +340,8 @@ class PytorchGPTQTrainer(GPTQTrainer):
             Logger.critical(f"Cannot update GPTQ graph: Layer with name '{name}' is missing or not unique. "
                             f"Ensure each layer has a unique name and exists within the graph for updates.")
         node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
+                                                              fw_info=self.fw_info)
         # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
         # To enable GPTQ for other attributes, this code needs to be modified.
         weights, weight_quant_config, activation_quant_config = \
model_compression_toolkit/gptq/pytorch/graph_info.py
@@ -16,6 +16,7 @@ import torch
 import torch.nn as nn
 from typing import List
 from model_compression_toolkit.core.pytorch.constants import BIAS
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from mct_quantizers import PytorchQuantizationWrapper
@@ -42,7 +43,8 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,

     for layer in fxp_model.modules():
         if isinstance(layer, PytorchQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_PYTORCH_INFO)

             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
model_compression_toolkit/gptq/pytorch/quantization_facade.py
@@ -39,7 +39,7 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH


 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
+    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
     from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss
@@ -142,8 +142,6 @@ if FOUND_TORCH:
                                                 gradual_activation_quantization_config=gradual_quant_config,
                                                 log_function=log_function)

-
-    @set_pytorch_info
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
                                                     target_resource_utilization: ResourceUtilization = None,
@@ -218,7 +216,8 @@ if FOUND_TORCH:
             Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                             "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
                             "or provide a valid mixed-precision configuration.")
-        tb_w = init_tensorboard_writer()
+
+        tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)

         fw_impl = GPTQPytorchImplemantation()
@@ -234,6 +233,7 @@ if FOUND_TORCH:
         graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
                                                                                       representative_data_gen=representative_data_gen,
                                                                                       core_config=core_config,
+                                                                                      fw_info=DEFAULT_PYTORCH_INFO,
                                                                                       fw_impl=fw_impl,
                                                                                       fqc=framework_quantization_capabilities,
                                                                                       target_resource_utilization=target_resource_utilization,
@@ -250,6 +250,7 @@ if FOUND_TORCH:
                                          gptq_config,
                                          representative_data_gen,
                                          gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
+                                         DEFAULT_PYTORCH_INFO,
                                          fw_impl,
                                          tb_w,
                                          hessian_info_service=hessian_info_service)
@@ -259,7 +260,8 @@ if FOUND_TORCH:
                                    tb_w,
                                    float_graph,
                                    graph_gptq,
-                                   fw_impl)
+                                   fw_impl,
+                                   DEFAULT_PYTORCH_INFO)

         exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
         if framework_quantization_capabilities.tpc.add_metadata:
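The PyTorch facade keeps the same public call shape; a hedged mirror of the Keras sketch above, again with illustrative sizes and data:

```python
import torch
import model_compression_toolkit as mct

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())

def representative_data_gen():
    yield [torch.randn(1, 3, 32, 32)]

gptq_config = mct.gptq.get_pytorch_gptq_config(n_epochs=5)
quantized_model, quantization_info = mct.gptq.pytorch_gradient_post_training_quantization(
    model, representative_data_gen, gptq_config=gptq_config)
```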
model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py
@@ -18,6 +18,7 @@ import torch
 from torch import nn

 from mct_quantizers import PytorchQuantizationWrapper
+from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq


@@ -60,7 +61,8 @@ class SoftQuantizerRegularization:
         b = self.beta_scheduler(self.count_iter)
         reg = 0
         for layer, w in zip(layers, layer_weights):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
+                                                                  fw_info=DEFAULT_PYTORCH_INFO)

             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
             soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
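For intuition about the regularizer in this last hunk: each soft target `st` in [0, 1] contributes `1 - |2(st - 0.5)|^b`, so weights whose rounding is still undecided (`st` near 0.5) are penalized most, and the penalty vanishes as `st` commits to 0 or 1. A numeric sketch with illustrative values:

```python
import torch

st = torch.tensor([0.02, 0.5, 0.93])  # illustrative soft rounding targets
b = 4.0                               # beta from the annealing scheduler
soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
print(soft_loss)  # ~1.60; the undecided 0.5 entry contributes the full 1.0
```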