mct-nightly 2.4.0.20250617.613__py3-none-any.whl → 2.4.0.20250619.621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/METADATA +1 -1
  2. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/RECORD +123 -123
  3. model_compression_toolkit/__init__.py +1 -1
  4. model_compression_toolkit/core/analyzer.py +2 -5
  5. model_compression_toolkit/core/common/back2framework/base_model_builder.py +0 -3
  6. model_compression_toolkit/core/common/framework_implementation.py +10 -22
  7. model_compression_toolkit/core/common/framework_info.py +105 -68
  8. model_compression_toolkit/core/common/graph/base_graph.py +15 -42
  9. model_compression_toolkit/core/common/graph/base_node.py +103 -42
  10. model_compression_toolkit/core/common/graph/functional_node.py +18 -1
  11. model_compression_toolkit/core/common/graph/virtual_activation_weights_node.py +7 -13
  12. model_compression_toolkit/core/common/mixed_precision/bit_width_setter.py +8 -18
  13. model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py +4 -7
  14. model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py +2 -3
  15. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py +2 -5
  16. model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py +3 -6
  17. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py +5 -10
  18. model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +2 -5
  19. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/metric_calculators.py +4 -8
  20. model_compression_toolkit/core/common/mixed_precision/sensitivity_eval/sensitivity_evaluation.py +2 -7
  21. model_compression_toolkit/core/common/model_collector.py +10 -20
  22. model_compression_toolkit/core/common/model_validation.py +1 -4
  23. model_compression_toolkit/core/common/network_editors/actions.py +14 -38
  24. model_compression_toolkit/core/common/network_editors/edit_network.py +1 -4
  25. model_compression_toolkit/core/common/pruning/channels_grouping.py +1 -5
  26. model_compression_toolkit/core/common/pruning/greedy_mask_calculator.py +0 -6
  27. model_compression_toolkit/core/common/pruning/importance_metrics/lfh_importance_metric.py +5 -15
  28. model_compression_toolkit/core/common/pruning/mask/per_channel_mask.py +3 -7
  29. model_compression_toolkit/core/common/pruning/mask/per_simd_group_mask.py +2 -4
  30. model_compression_toolkit/core/common/pruning/memory_calculator.py +5 -13
  31. model_compression_toolkit/core/common/pruning/prune_graph.py +1 -4
  32. model_compression_toolkit/core/common/pruning/pruner.py +1 -6
  33. model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py +5 -13
  34. model_compression_toolkit/core/common/pruning/pruning_section.py +9 -18
  35. model_compression_toolkit/core/common/quantization/candidate_node_quantization_config.py +2 -1
  36. model_compression_toolkit/core/common/quantization/filter_nodes_candidates.py +10 -12
  37. model_compression_toolkit/core/common/quantization/node_quantization_config.py +4 -3
  38. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_computation.py +5 -11
  39. model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +8 -22
  40. model_compression_toolkit/core/common/statistics_correction/apply_activation_bias_correction_to_graph.py +1 -2
  41. model_compression_toolkit/core/common/statistics_correction/apply_bias_correction_to_graph.py +2 -3
  42. model_compression_toolkit/core/common/statistics_correction/apply_second_moment_correction_to_graph.py +5 -13
  43. model_compression_toolkit/core/common/statistics_correction/compute_activation_bias_correction_of_graph.py +3 -9
  44. model_compression_toolkit/core/common/statistics_correction/compute_bias_correction_of_graph.py +3 -10
  45. model_compression_toolkit/core/common/statistics_correction/statistics_correction.py +1 -6
  46. model_compression_toolkit/core/common/substitutions/batchnorm_reconstruction.py +2 -3
  47. model_compression_toolkit/core/common/substitutions/batchnorm_refusing.py +3 -6
  48. model_compression_toolkit/core/common/substitutions/scale_equalization.py +5 -21
  49. model_compression_toolkit/core/common/substitutions/shift_negative_activation.py +5 -19
  50. model_compression_toolkit/core/common/substitutions/virtual_activation_weights_composition.py +1 -3
  51. model_compression_toolkit/core/common/substitutions/weights_activation_split.py +1 -1
  52. model_compression_toolkit/core/common/visualization/nn_visualizer.py +3 -8
  53. model_compression_toolkit/core/common/visualization/tensorboard_writer.py +6 -8
  54. model_compression_toolkit/core/graph_prep_runner.py +2 -16
  55. model_compression_toolkit/core/keras/back2framework/float_model_builder.py +0 -4
  56. model_compression_toolkit/core/keras/back2framework/keras_model_builder.py +0 -5
  57. model_compression_toolkit/core/keras/back2framework/mixed_precision_model_builder.py +8 -15
  58. model_compression_toolkit/core/keras/back2framework/quantized_model_builder.py +0 -4
  59. model_compression_toolkit/core/keras/default_framework_info.py +138 -87
  60. model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_folding.py +2 -7
  61. model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py +0 -1
  62. model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py +3 -5
  63. model_compression_toolkit/core/keras/graph_substitutions/substitutions/scale_equalization.py +8 -16
  64. model_compression_toolkit/core/keras/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  65. model_compression_toolkit/core/keras/hessian/weights_hessian_scores_calculator_keras.py +3 -13
  66. model_compression_toolkit/core/keras/keras_implementation.py +15 -35
  67. model_compression_toolkit/core/keras/keras_model_validation.py +6 -7
  68. model_compression_toolkit/core/keras/keras_node_prior_info.py +4 -13
  69. model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py +11 -34
  70. model_compression_toolkit/core/keras/resource_utilization_data_facade.py +2 -2
  71. model_compression_toolkit/core/keras/statistics_correction/keras_compute_activation_bias_correction_of_graph.py +0 -3
  72. model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py +3 -12
  73. model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py +9 -16
  74. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +1 -5
  75. model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py +2 -3
  76. model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py +0 -4
  77. model_compression_toolkit/core/pytorch/default_framework_info.py +100 -74
  78. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/const_holder_conv.py +3 -4
  79. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py +4 -8
  80. model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py +1 -4
  81. model_compression_toolkit/core/pytorch/hessian/weights_hessian_scores_calculator_pytorch.py +3 -12
  82. model_compression_toolkit/core/pytorch/pruning/pruning_pytorch_implementation.py +16 -41
  83. model_compression_toolkit/core/pytorch/pytorch_implementation.py +12 -32
  84. model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py +1 -5
  85. model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py +2 -2
  86. model_compression_toolkit/core/pytorch/statistics_correction/pytorch_compute_activation_bias_correction_of_graph.py +0 -3
  87. model_compression_toolkit/core/quantization_prep_runner.py +4 -9
  88. model_compression_toolkit/core/runner.py +5 -15
  89. model_compression_toolkit/data_generation/keras/optimization_functions/lr_scheduler.py +8 -8
  90. model_compression_toolkit/data_generation/pytorch/optimization_functions/lr_scheduler.py +11 -11
  91. model_compression_toolkit/exporter/model_exporter/keras/keras_export_facade.py +2 -0
  92. model_compression_toolkit/exporter/model_exporter/pytorch/fakely_quant_onnx_pytorch_exporter.py +19 -17
  93. model_compression_toolkit/exporter/model_exporter/pytorch/pytorch_export_facade.py +2 -0
  94. model_compression_toolkit/gptq/common/gptq_graph.py +5 -11
  95. model_compression_toolkit/gptq/common/gptq_training.py +1 -8
  96. model_compression_toolkit/gptq/keras/gptq_training.py +3 -9
  97. model_compression_toolkit/gptq/keras/graph_info.py +4 -6
  98. model_compression_toolkit/gptq/keras/quantization_facade.py +5 -8
  99. model_compression_toolkit/gptq/keras/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  100. model_compression_toolkit/gptq/pytorch/gptq_training.py +3 -9
  101. model_compression_toolkit/gptq/pytorch/graph_info.py +1 -3
  102. model_compression_toolkit/gptq/pytorch/quantization_facade.py +5 -7
  103. model_compression_toolkit/gptq/pytorch/quantizer/soft_rounding/soft_quantizer_reg.py +1 -3
  104. model_compression_toolkit/gptq/runner.py +1 -7
  105. model_compression_toolkit/pruning/keras/pruning_facade.py +2 -3
  106. model_compression_toolkit/pruning/pytorch/pruning_facade.py +2 -3
  107. model_compression_toolkit/ptq/keras/quantization_facade.py +5 -10
  108. model_compression_toolkit/ptq/pytorch/quantization_facade.py +4 -8
  109. model_compression_toolkit/ptq/runner.py +1 -4
  110. model_compression_toolkit/qat/common/qat_config.py +2 -6
  111. model_compression_toolkit/qat/keras/quantization_facade.py +7 -10
  112. model_compression_toolkit/qat/pytorch/quantization_facade.py +6 -10
  113. model_compression_toolkit/xquant/common/core_report_generator.py +1 -1
  114. model_compression_toolkit/xquant/common/framework_report_utils.py +0 -3
  115. model_compression_toolkit/xquant/common/model_folding_utils.py +1 -6
  116. model_compression_toolkit/xquant/common/tensorboard_utils.py +1 -4
  117. model_compression_toolkit/xquant/keras/keras_report_utils.py +3 -8
  118. model_compression_toolkit/xquant/keras/tensorboard_utils.py +0 -3
  119. model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py +5 -8
  120. model_compression_toolkit/xquant/pytorch/tensorboard_utils.py +0 -3
  121. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/WHEEL +0 -0
  122. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/licenses/LICENSE.md +0 -0
  123. {mct_nightly-2.4.0.20250617.613.dist-info → mct_nightly-2.4.0.20250619.621.dist-info}/top_level.txt +0 -0
@@ -60,10 +60,10 @@ class ReduceLROnPlateauWithReset:
         # Attach optimizer
         if not isinstance(optimizer, Optimizer):
             Logger.critical('{} is not an Optimizer'.format(
-                type(optimizer).__name__))  # pragma: no cover
+                type(optimizer).__name__))  # pragma: no cover
         self.optimizer = optimizer
 
-        if isinstance(min_lr, (list, tuple)):
+        if isinstance(min_lr, (list, tuple)):  # pragma: no cover
             if len(min_lr) != len(optimizer.param_groups):
                 Logger.critical("expected {} min_lrs, got {}".format(
                     len(optimizer.param_groups), len(min_lr)))  # pragma: no cover
@@ -117,7 +117,7 @@ class ReduceLROnPlateauWithReset:
         self.num_bad_epochs += 1
 
         # Handle cooldown period
-        if self.in_cooldown:
+        if self.in_cooldown:  # pragma: no cover
             self.cooldown_counter -= 1
             self.num_bad_epochs = 0  # Ignore any bad epochs in cooldown
 
@@ -142,7 +142,7 @@ class ReduceLROnPlateauWithReset:
                 new_lr = max(old_lr * self.factor, self.min_lrs[i])
                 if old_lr - new_lr > self.eps:
                     param_group['lr'] = new_lr
-                    if self.verbose:
+                    if self.verbose:  # pragma: no cover
                         epoch_str = ("%.2f" if isinstance(epoch, float) else "%.5d") % epoch
                         print('Epoch {}: reducing learning rate'
                               ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))
@@ -168,19 +168,19 @@ class ReduceLROnPlateauWithReset:
         Returns:
             bool: True if the new value is better, False otherwise.
         """
-        if best is None:
+        if best is None:  # pragma: no cover
             return True
 
         if self.mode == 'min' and self.threshold_mode == 'rel':
             rel_epsilon = 1. - self.threshold
             return a < best * rel_epsilon
-        elif self.mode == 'min' and self.threshold_mode == 'abs':
+        elif self.mode == 'min' and self.threshold_mode == 'abs':  # pragma: no cover
             return a < best - self.threshold
-        elif self.mode == 'max' and self.threshold_mode == 'rel':
+        elif self.mode == 'max' and self.threshold_mode == 'rel':  # pragma: no cover
             rel_epsilon = self.threshold + 1.
             return a > best * rel_epsilon
         else:  # mode == 'max' and threshold_mode == 'abs':
-            return a > best + self.threshold
+            return a > best + self.threshold  # pragma: no cover
 
     def _init_is_better(self) -> None:
         """
@@ -197,9 +197,9 @@ class ReduceLROnPlateauWithReset:
         if self.mode == 'min':
             self.mode_worse = float('inf')
         else:  # mode == 'max':
-            self.mode_worse = float('-inf')
+            self.mode_worse = float('-inf')  # pragma: no cover
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> Dict[str, Any]:  # pragma: no cover
         """
         Return the state of the scheduler as a dictionary.
 
@@ -208,7 +208,7 @@
         """
         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:  # pragma: no cover
         """
         Load the scheduler state.
 
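The hunks above mark rarely exercised branches of ReduceLROnPlateauWithReset with # pragma: no cover. The class follows the interface of the stock PyTorch plateau scheduler; a minimal, runnable sketch of that plateau pattern, shown with torch.optim.lr_scheduler.ReduceLROnPlateau for illustration (the MCT subclass adds reset behavior on top of this same interface, so exact constructor arguments may differ):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2, threshold=1e-4, min_lr=1e-5)

for epoch in range(10):
    val_loss = 1.0  # a metric that never improves
    scheduler.step(val_loss)  # after `patience` bad epochs, lr is multiplied by `factor`

print(optimizer.param_groups[0]['lr'])  # reduced from the initial 0.1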
@@ -21,6 +21,7 @@ from model_compression_toolkit.logger import Logger
 
 if FOUND_TF:
     import keras
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from model_compression_toolkit.exporter.model_wrapper.keras.validate_layer import is_keras_layer_exportable
     from model_compression_toolkit.exporter.model_exporter.keras.fakely_quant_keras_exporter import \
         FakelyQuantKerasExporter
@@ -36,6 +37,7 @@ if FOUND_TF:
         KerasExportSerializationFormat.TFLITE: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.INT8]
     }
 
+    @set_keras_info
     def keras_export_model(model: keras.models.Model,
                            save_model_path: str,
                            is_layer_exportable_fn: Callable = is_keras_layer_exportable,
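The diff only shows @set_keras_info being applied; its body is not part of this changeset. A hypothetical, self-contained sketch of the pattern it suggests, where a decorator installs framework info into a process-wide registry before the wrapped API runs (the registry names here are assumptions, not MCT's actual internals):

from functools import wraps

_FW_INFO = None  # hypothetical process-wide registry slot

def get_fw_info():
    return _FW_INFO

def set_framework_info(info):
    """Install `info` into the registry before the wrapped API runs."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            global _FW_INFO
            _FW_INFO = info
            return func(*args, **kwargs)
        return wrapper
    return decorator

@set_framework_info({"framework": "keras"})
def export_model_stub():
    return get_fw_info()

assert export_model_stub() == {"framework": "keras"}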
@@ -73,23 +73,25 @@ if FOUND_ONNX:
         Returns:
             Fake-quant PyTorch model.
         """
-        # List all activation quantization holders with num_bits>8 and replace them with Identity, because
-        # ONNX doesn't support quantization of more than 8 bits for torch.fake_quantize_per_tensor_affine.
-        act_holder_list = [n for n, m in self.model.named_modules()
-                           if isinstance(m, PytorchActivationQuantizationHolder) and
-                           m.activation_holder_quantizer.num_bits > 8]
-        for act_holder in act_holder_list:  # pragma: no cover
-            obj = self.model
-            attrs = act_holder.split(".")
-            for a in attrs[:-1]:
-                obj = getattr(obj, a)
-            if hasattr(obj, attrs[-1]):
-                delattr(obj, attrs[-1])
-                setattr(obj, attrs[-1], torch.nn.Identity())
-            else:
-                Logger.info(f"During removal of activation quantization of a quantizer (with bits > 8) in ONNX FQ "
-                            f"export, deletion of activation holder '{act_holder}' failed — could not locate one or "
-                            f"more intermediate attributes in the path.")
+        # When exporting with the fakely-quant quantization format, list all activation quantization holders with
+        # num_bits>8 and replace them with Identity, because ONNX doesn't support quantization of more than 8 bits
+        # for torch.fake_quantize_per_tensor_affine.
+        if not self._use_onnx_custom_quantizer_ops:
+            act_holder_list = [n for n, m in self.model.named_modules()
+                               if isinstance(m, PytorchActivationQuantizationHolder) and
+                               m.activation_holder_quantizer.num_bits > 8]
+            for act_holder in act_holder_list:  # pragma: no cover
+                obj = self.model
+                attrs = act_holder.split(".")
+                for a in attrs[:-1]:
+                    obj = getattr(obj, a)
+                if hasattr(obj, attrs[-1]):
+                    delattr(obj, attrs[-1])
+                    setattr(obj, attrs[-1], torch.nn.Identity())
+                else:
+                    Logger.info(f"During removal of activation quantization of a quantizer (with bits > 8) in ONNX "
+                                f"FQ export, deletion of activation holder '{act_holder}' failed — could not locate "
+                                f"one or more intermediate attributes in the path.")
 
         for layer in self.model.children():
             self.is_layer_exportable_fn(layer)
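The loop above removes >8-bit activation quantizers by walking a dotted module path and swapping the leaf module for torch.nn.Identity. A self-contained sketch of that traversal (the model and path here are illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())

def replace_with_identity(model: nn.Module, dotted_path: str) -> bool:
    obj = model
    attrs = dotted_path.split(".")
    for a in attrs[:-1]:   # descend to the parent module
        obj = getattr(obj, a)
    if hasattr(obj, attrs[-1]):
        setattr(obj, attrs[-1], nn.Identity())  # swap the leaf in place
        return True
    return False

replace_with_identity(model, "1")  # replaces the ReLU with Identity
assert isinstance(model[1], nn.Identity)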
@@ -27,6 +27,7 @@ DEFAULT_ONNX_OPSET_VERSION = 15
 
 if FOUND_TORCH:
     import torch.nn
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_onnx_pytorch_exporter import FakelyQuantONNXPyTorchExporter
     from model_compression_toolkit.exporter.model_exporter.pytorch.fakely_quant_torchscript_pytorch_exporter import FakelyQuantTorchScriptPyTorchExporter
     from model_compression_toolkit.exporter.model_wrapper.pytorch.validate_layer import is_pytorch_layer_exportable
@@ -41,6 +42,7 @@ if FOUND_TORCH:
         PytorchExportSerializationFormat.ONNX: [QuantizationFormat.FAKELY_QUANT, QuantizationFormat.MCTQ]
     }
 
+    @set_pytorch_info
     def pytorch_export_model(model: torch.nn.Module,
                              save_model_path: str,
                              repr_dataset: Callable,
@@ -14,8 +14,8 @@
 # ==============================================================================
 from typing import Tuple, List
 
-from model_compression_toolkit.core import FrameworkInfo
 from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from model_compression_toolkit.core.common.graph.base_graph import Graph
 from model_compression_toolkit.core.common.graph.base_node import BaseNode
 
@@ -40,8 +40,7 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
     compare_points_name = []
     for n in input_graph.get_topo_sorted_nodes():
         # only nodes with kernel attribute are currently trained with GPTQ and are used as compare points
-        kernel_attr = input_graph.fw_info.get_kernel_op_attributes(n.type)[0]
-        if kernel_attr is not None and n.is_weights_quantization_enabled(kernel_attr) and not n.reuse:
+        if n.kernel_attr is not None and n.is_weights_quantization_enabled(n.kernel_attr) and not n.reuse:
             compare_points.append(n)
             compare_points_name.append(n.name)
             compare_points_std.append(n.prior_info.std_output)
@@ -49,20 +48,15 @@ def get_compare_points(input_graph: Graph) -> Tuple[List[BaseNode], List[str], L
     return compare_points, compare_points_name, compare_points_mean, compare_points_std
 
 
-def get_kernel_attribute_name_for_gptq(layer_type: type, fw_info: FrameworkInfo) -> str:
+def get_kernel_attribute_name_for_gptq(layer_type: type) -> str:
     """
     Returns a layer's kernel attribute name for GPTQ training purposes.
 
     Args:
         layer_type: A type of model's layer.
-        fw_info: A FrameworkInfo object.
 
     Returns: The name of the kernel attribute.
 
     """
-    kernel_attribute = fw_info.get_kernel_op_attributes(layer_type)
-    if len(kernel_attribute) != 1:
-        Logger.critical(  # pragma: no cover
-            f"In GPTQ training, only the kernel weights attribute should be trained. "
-            f"However, the number of kernel attributes is {len(kernel_attribute)}.")
-    return kernel_attribute[0]
+
+    return get_fw_info().get_kernel_op_attribute(layer_type)
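This is the recurring refactor in this release: FrameworkInfo is no longer threaded through call sites but resolved via a global accessor. The call-site migration, as it appears in the keras and pytorch graph_info hunks below:

# Before (...613): fw_info threaded explicitly through every call site
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
                                                      fw_info=DEFAULT_KERAS_INFO)

# After (...621): the FrameworkInfo is resolved through the global registry
kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))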
@@ -44,7 +44,6 @@ class GPTQTrainer(ABC):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: GPTQFrameworkImplemantation,
-                 fw_info: FrameworkInfo,
                  representative_data_gen_fn: Callable[[], Generator],
                  hessian_info_service: HessianInfoService = None):
         """
@@ -58,7 +57,6 @@ class GPTQTrainer(ABC):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfig with parameters about the tuning process.
             fw_impl: Framework implementation
-            fw_info: Framework information
             representative_data_gen_fn: factory for representative data generator.
             hessian_info_service: HessianInfoService for fetching and computing Hessian-approximation information.
         """
@@ -66,7 +64,6 @@ class GPTQTrainer(ABC):
         self.graph_quant = copy.deepcopy(graph_quant)
         self.gptq_config = gptq_config
         self.fw_impl = fw_impl
-        self.fw_info = fw_info
         self.representative_data_gen_fn = representative_data_gen_fn
 
         def _get_total_grad_steps():
@@ -83,8 +80,7 @@ class GPTQTrainer(ABC):
 
         self.float_model, self.float_user_info = fw_impl.model_builder(self.graph_float,
                                                                        mode=ModelBuilderMode.FLOAT,
-                                                                       append2output=self.compare_points,
-                                                                       fw_info=self.fw_info)
+                                                                       append2output=self.compare_points)
 
         self.fxp_model, self.gptq_user_info = self.build_gptq_model()
         if self.gptq_config.hessian_weights_config:
@@ -288,7 +284,6 @@ def gptq_training(graph_float: Graph,
                   gptq_config: GradientPTQConfig,
                   representative_data_gen: Callable,
                   fw_impl: GPTQFrameworkImplemantation,
-                  fw_info: FrameworkInfo,
                   hessian_info_service: HessianInfoService = None) -> Graph:
     """
     GPTQ training process using knowledge distillation with a teacher network (float model) and a student network (quantized model).
@@ -298,7 +293,6 @@ def gptq_training(graph_float: Graph,
         gptq_config: GradientPTQConfig with parameters about the tuning process.
         representative_data_gen: Dataset to use for inputs of the models.
         fw_impl: Framework implementation
-        fw_info: Framework information
         hessian_info_service: HessianInfoService to fetch information based on the Hessian approximation.
 
     Returns:
@@ -312,7 +306,6 @@ def gptq_training(graph_float: Graph,
                               graph_quant,
                               gptq_config,
                               fw_impl,
-                              fw_info,
                               representative_data_gen,
                               hessian_info_service=hessian_info_service)
 
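The same removal applies to the GPTQ entry point itself. Condensed from the signature hunks above (intermediate parameters elided, as they are not shown in this diff):

- def gptq_training(graph_float, ..., representative_data_gen, fw_impl, fw_info, hessian_info_service=None)
+ def gptq_training(graph_float, ..., representative_data_gen, fw_impl, hessian_info_service=None)

Callers that previously passed DEFAULT_KERAS_INFO or DEFAULT_PYTORCH_INFO now appear to rely on the @set_keras_info / @set_pytorch_info decorators on the public facades to populate the global FrameworkInfo.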
@@ -65,7 +65,6 @@ class KerasGPTQTrainer(GPTQTrainer):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
-                 fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
                  hessian_info_service: HessianInfoService = None):
         """
@@ -79,7 +78,6 @@ class KerasGPTQTrainer(GPTQTrainer):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfig with parameters about the tuning process.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-            fw_info: Framework information.
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianScoresService for fetching and computing Hessian's approximation scores.
 
@@ -94,7 +92,6 @@ class KerasGPTQTrainer(GPTQTrainer):
                          graph_quant,
                          gptq_config,
                          fw_impl,
-                         fw_info,
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
@@ -210,8 +207,7 @@ class KerasGPTQTrainer(GPTQTrainer):
         Returns:
             A boolean whether the layer is to be wrapped with a QuantizeWrapper
         """
-        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
-        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)
+        return node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)
 
     def gptq_wrapper(self,
                      n: common.BaseNode,
@@ -230,7 +226,7 @@ class KerasGPTQTrainer(GPTQTrainer):
         # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
         weights_quantizers, _ = quantization_builder(n,
                                                      self.gptq_config,  # TODO: split quantizers building into two functions: for weights and activations
-                                                     self.fw_info.get_kernel_op_attributes(n.type)[0])
+                                                     n.kernel_attr)
         if len(weights_quantizers) > 0:
             return KerasTrainableQuantizationWrapper(layer,
                                                      weights_quantizers=weights_quantizers)
@@ -271,7 +267,6 @@ class KerasGPTQTrainer(GPTQTrainer):
 
         gptq_model, gptq_user_info = KerasModelBuilder(graph=self.graph_quant,
                                                        append2output=self.compare_points,
-                                                       fw_info=self.fw_info,
                                                        return_float_outputs=True,
                                                        wrapper=self.gptq_wrapper,
                                                        get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
@@ -431,8 +426,7 @@ class KerasGPTQTrainer(GPTQTrainer):
             Logger.critical(f"Unable to update the GPTQ graph because the layer named '{layer.layer.name}' could not be found. "
                             f"Verify that the layer names in the GPTQ model match those in the graph.")
         node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
-                                                              fw_info=self.fw_info)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type)
         # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
         # To enable GPTQ for other attributes, this code needs to be modified.
         weights, weight_quant_config, activation_quant_config = \
@@ -16,8 +16,8 @@
 import tensorflow as tf
 from typing import Tuple, List
 from model_compression_toolkit.core.keras.constants import USE_BIAS
+from model_compression_toolkit.core.common.framework_info import get_fw_info
 from tensorflow.keras.models import Model
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
@@ -44,8 +44,7 @@ def get_gptq_trainable_parameters(fxp_model: Model,
 
     for layer in fxp_model.layers:
         if isinstance(layer, KerasTrainableQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_KERAS_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
 
             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
@@ -57,9 +56,8 @@ def get_gptq_trainable_parameters(fxp_model: Model,
             trainable_threshold.extend(quantizer_trainable_threshold)
 
             if add_bias:
-                kernel_ops_attrs = DEFAULT_KERAS_INFO.kernel_ops_attributes_mapping.get(type(layer.layer))
-                use_bias = kernel_ops_attrs is not None and kernel_ops_attrs[0] is not None \
-                    and layer.layer.get_config().get(USE_BIAS)
+                kernel_ops_attr = get_fw_info().get_kernel_op_attribute(type(layer.layer))
+                use_bias = kernel_ops_attr is not None and layer.layer.get_config().get(USE_BIAS)
                 if use_bias is not None and use_bias and layer.layer.bias is not None:
                     bias_weights.append([layer.layer.bias])
@@ -41,7 +41,7 @@ from model_compression_toolkit.metadata import create_model_metadata
 
 if FOUND_TF:
     import tensorflow as tf
-    from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
+    from model_compression_toolkit.core.keras.default_framework_info import set_keras_info
     from model_compression_toolkit.gptq.keras.gptq_keras_implementation import GPTQKerasImplemantation
     from model_compression_toolkit.core.keras.keras_model_validation import KerasModelValidation
     from tensorflow.keras.models import Model
@@ -152,6 +152,7 @@ if FOUND_TF:
             gradual_activation_quantization_config=gradual_quant_config)
 
 
+    @set_keras_info
     def keras_gradient_post_training_quantization(in_model: Model, representative_data_gen: Callable,
                                                   gptq_config: GradientPTQConfig,
                                                   gptq_representative_data_gen: Callable = None,
@@ -234,8 +235,7 @@ if FOUND_TF:
         if core_config.debug_config.bypass:
            return in_model, None
 
-        KerasModelValidation(model=in_model,
-                             fw_info=DEFAULT_KERAS_INFO).validate()
+        KerasModelValidation(model=in_model).validate()
 
         if core_config.is_mixed_precision_enabled:
             if not isinstance(core_config.mixed_precision_config, MixedPrecisionQuantizationConfig):
@@ -243,7 +243,7 @@ if FOUND_TF:
                                 "Ensure usage of the correct API for keras_post_training_quantization "
                                 "or provide a valid mixed-precision configuration.")  # pragma: no cover
 
-        tb_w = init_tensorboard_writer(DEFAULT_KERAS_INFO)
+        tb_w = init_tensorboard_writer()
 
         fw_impl = GPTQKerasImplemantation()
 
@@ -257,7 +257,6 @@ if FOUND_TF:
         tg, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=in_model,
                                                                                    representative_data_gen=representative_data_gen,
                                                                                    core_config=core_config,
-                                                                                   fw_info=DEFAULT_KERAS_INFO,
                                                                                    fw_impl=fw_impl,
                                                                                    fqc=framework_platform_capabilities,
                                                                                    target_resource_utilization=target_resource_utilization,
@@ -271,7 +270,6 @@ if FOUND_TF:
                                   gptq_config,
                                   representative_data_gen,
                                   gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
-                                  DEFAULT_KERAS_INFO,
                                   fw_impl,
                                   tb_w,
                                   hessian_info_service=hessian_info_service)
@@ -283,8 +281,7 @@ if FOUND_TF:
                                   tb_w,
                                   float_graph,
                                   tg_gptq,
-                                  fw_impl,
-                                  DEFAULT_KERAS_INFO)
+                                  fw_impl)
 
         exportable_model, user_info = get_exportable_keras_model(tg_gptq)
         if framework_platform_capabilities.tpc.add_metadata:
@@ -17,7 +17,6 @@ from typing import List, Callable
 import tensorflow as tf
 from keras import Model
 
-from model_compression_toolkit.core.keras.default_framework_info import DEFAULT_KERAS_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.trainable_infrastructure import KerasTrainableQuantizationWrapper
 
@@ -66,8 +65,7 @@ class SoftQuantizerRegularization:
 
         # Compute the regularization term without concatenating
         for i, layer in enumerate(layers):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_KERAS_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
 
             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
 
@@ -54,7 +54,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                  graph_quant: Graph,
                  gptq_config: GradientPTQConfig,
                  fw_impl: FrameworkImplementation,
-                 fw_info: FrameworkInfo,
                  representative_data_gen: Callable,
                  hessian_info_service: HessianInfoService = None):
         """
@@ -68,7 +67,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
             graph_quant: Graph to build a quantized networks from.
             gptq_config: GradientPTQConfigV2 with parameters about the tuning process.
             fw_impl: FrameworkImplementation object with a specific framework methods implementation.
-            fw_info: Framework information
             representative_data_gen: Dataset to use for inputs of the models.
             hessian_info_service: HessianInfoService to fetch info based on the hessian approximation of the float model.
         """
@@ -81,7 +79,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
                          graph_quant,
                          gptq_config,
                          fw_impl,
-                         fw_info,
                          representative_data_gen_fn=representative_data_gen,
                          hessian_info_service=hessian_info_service)
 
@@ -167,8 +164,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
             A boolean whether the layer is to be wrapped with a Quantization Wrapper.
         """
 
-        kernel_attr = self.fw_info.get_kernel_op_attributes(node.type)[0]
-        return kernel_attr is not None and node.is_weights_quantization_enabled(kernel_attr)
+        return node.kernel_attr is not None and node.is_weights_quantization_enabled(node.kernel_attr)
 
     def gptq_wrapper(self,
                      n: BaseNode,
@@ -187,7 +183,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
         # If we are here, then the node has a kernel attribute to quantize and training during GPTQ
         weights_quantizers, _ = quantization_builder(n,
                                                      self.gptq_config,
-                                                     self.fw_info.get_kernel_op_attributes(n.type)[0])
+                                                     n.kernel_attr)
 
         if len(weights_quantizers) > 0:
             return PytorchQuantizationWrapper(layer,
@@ -224,7 +220,6 @@ class PytorchGPTQTrainer(GPTQTrainer):
         """
         gptq_model, gptq_user_info = PyTorchModelBuilder(graph=self.graph_quant,
                                                          append2output=self.compare_points,
-                                                         fw_info=self.fw_info,
                                                          wrapper=self.gptq_wrapper,
                                                          return_float_outputs=True,
                                                          get_activation_quantizer_holder_fn=self.get_activation_quantizer_holder).build_model()
@@ -340,8 +335,7 @@ class PytorchGPTQTrainer(GPTQTrainer):
             Logger.critical(f"Cannot update GPTQ graph: Layer with name '{name}' is missing or not unique. "
                             f"Ensure each layer has a unique name and exists within the graph for updates.")
         node = node[0]
-        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type,
-                                                              fw_info=self.fw_info)
+        kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=node.type)
         # TODO: only kernel attributes are currently trained in GPTQ, so only the kernel weights need to be updated.
         # To enable GPTQ for other attributes, this code needs to be modified.
         weights, weight_quant_config, activation_quant_config = \
@@ -16,7 +16,6 @@ import torch
 import torch.nn as nn
 from typing import List
 from model_compression_toolkit.core.pytorch.constants import BIAS
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 from model_compression_toolkit.logger import Logger
 from mct_quantizers import PytorchQuantizationWrapper
@@ -43,8 +42,7 @@ def get_gptq_trainable_parameters(fxp_model: nn.Module,
 
     for layer in fxp_model.modules():
         if isinstance(layer, PytorchQuantizationWrapper):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_PYTORCH_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
 
             # collect trainable weights per quantizer
             if kernel_attribute not in layer.weights_quantizers:
@@ -39,7 +39,7 @@ from model_compression_toolkit.verify_packages import FOUND_TORCH
 
 
 if FOUND_TORCH:
-    from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
+    from model_compression_toolkit.core.pytorch.default_framework_info import set_pytorch_info
     from model_compression_toolkit.gptq.pytorch.gptq_pytorch_implementation import GPTQPytorchImplemantation
     from model_compression_toolkit.target_platform_capabilities.constants import DEFAULT_TP_MODEL
     from model_compression_toolkit.gptq.pytorch.gptq_loss import multiple_tensors_mse_loss, sample_layer_attention_loss
@@ -142,6 +142,8 @@ if FOUND_TORCH:
             gradual_activation_quantization_config=gradual_quant_config,
             log_function=log_function)
 
+
+    @set_pytorch_info
     def pytorch_gradient_post_training_quantization(model: Module,
                                                     representative_data_gen: Callable,
                                                     target_resource_utilization: ResourceUtilization = None,
@@ -216,8 +218,7 @@ if FOUND_TORCH:
             Logger.critical("Given quantization config for mixed-precision is not of type 'MixedPrecisionQuantizationConfig'. "
                             "Ensure usage of the correct API for 'pytorch_gradient_post_training_quantization' "
                             "or provide a valid mixed-precision configuration.")
-
-        tb_w = init_tensorboard_writer(DEFAULT_PYTORCH_INFO)
+        tb_w = init_tensorboard_writer()
 
         fw_impl = GPTQPytorchImplemantation()
 
@@ -233,7 +234,6 @@ if FOUND_TORCH:
         graph, bit_widths_config, hessian_info_service, scheduling_info = core_runner(in_model=model,
                                                                                       representative_data_gen=representative_data_gen,
                                                                                       core_config=core_config,
-                                                                                      fw_info=DEFAULT_PYTORCH_INFO,
                                                                                       fw_impl=fw_impl,
                                                                                       fqc=framework_quantization_capabilities,
                                                                                       target_resource_utilization=target_resource_utilization,
@@ -250,7 +250,6 @@ if FOUND_TORCH:
                                    gptq_config,
                                    representative_data_gen,
                                    gptq_representative_data_gen if gptq_representative_data_gen else representative_data_gen,
-                                   DEFAULT_PYTORCH_INFO,
                                    fw_impl,
                                    tb_w,
                                    hessian_info_service=hessian_info_service)
@@ -260,8 +259,7 @@ if FOUND_TORCH:
                                    tb_w,
                                    float_graph,
                                    graph_gptq,
-                                   fw_impl,
-                                   DEFAULT_PYTORCH_INFO)
+                                   fw_impl)
 
         exportable_model, user_info = get_exportable_pytorch_model(graph_gptq)
         if framework_quantization_capabilities.tpc.add_metadata:
@@ -18,7 +18,6 @@ import torch
 from torch import nn
 
 from mct_quantizers import PytorchQuantizationWrapper
-from model_compression_toolkit.core.pytorch.default_framework_info import DEFAULT_PYTORCH_INFO
 from model_compression_toolkit.gptq.common.gptq_graph import get_kernel_attribute_name_for_gptq
 
 
@@ -61,8 +60,7 @@ class SoftQuantizerRegularization:
         b = self.beta_scheduler(self.count_iter)
         reg = 0
         for layer, w in zip(layers, layer_weights):
-            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer),
-                                                                  fw_info=DEFAULT_PYTORCH_INFO)
+            kernel_attribute = get_kernel_attribute_name_for_gptq(layer_type=type(layer.layer))
 
             st = layer.weights_quantizers[kernel_attribute].get_soft_targets()
             soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
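For reference, the regularization term in the loop above is the AdaRound-style soft-rounding penalty: a soft target near 0 or 1 costs nothing, a target stuck at 0.5 costs the most, and the exponent b (annealed by the beta scheduler) sharpens the penalty over training. A standalone numeric check of the same expression:

import torch

st = torch.tensor([0.02, 0.5, 0.93])  # soft rounding targets in [0, 1]
b = 4.0                               # annealed by the beta scheduler in practice
soft_loss = (1 - torch.pow(torch.abs(st - .5) * 2, b)).sum()
print(soft_loss)  # ~1.6; the mid-point target alone contributes a full 1.0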
@@ -37,7 +37,6 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
                 tb_w: TensorboardWriter,
                 tg: Graph,
                 tg_bias: Graph,
-                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 hessian_info_service: HessianInfoService = None) -> Graph:
     """
@@ -52,7 +51,6 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
         tb_w: TensorBoardWriter object to log events.
         tg: Float Reference Graph.
         tg_bias: Graph of quantized model.
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.).
         fw_impl: Framework implementation per framework
         hessian_info_service: HessianInfoService to fetch information based on the hessian approximation for the float model.
     Returns:
@@ -64,7 +62,6 @@ def _apply_gptq(gptq_config: GradientPTQConfig,
                             gptq_config,
                             representative_data_gen,
                             fw_impl,
-                            fw_info,
                             hessian_info_service=hessian_info_service)
 
     if tb_w is not None:
@@ -77,7 +74,6 @@ def gptq_runner(tg: Graph,
                 gptq_config: GradientPTQConfig,
                 representative_data_gen: Callable,
                 gptq_representative_data_gen: Callable,
-                fw_info: FrameworkInfo,
                 fw_impl: FrameworkImplementation,
                 tb_w: TensorboardWriter,
                 hessian_info_service: HessianInfoService = None) -> Graph:
@@ -91,7 +87,6 @@ def gptq_runner(tg: Graph,
         gptq_config: GradientPTQConfig with parameters about the tuning process.
         representative_data_gen: Dataset used for calibration.
         gptq_representative_data_gen: Dataset used for GPTQ training
-        fw_info: Information needed for quantization about the specific framework (e.g., kernel channels indices, groups of layers by how they should be quantized, etc.)
         fw_impl: FrameworkImplementation object with a specific framework methods implementation.
         tb_w: A TensorBoardWriter object initialized with the logger dir path if it was set, or None otherwise.
         hessian_info_service: HessianScoresService to fetch approximations of the hessian scores for the float model.
@@ -104,7 +99,7 @@ def gptq_runner(tg: Graph,
     #############################################
     # Apply Statistics Correction
     #############################################
-    tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_info, fw_impl, tb_w)
+    tg_bias = apply_statistics_correction(tg, representative_data_gen, core_config, fw_impl, tb_w)
 
     if tb_w is not None:
         tb_w.add_graph(tg_bias, 'after_bias_correction')
@@ -117,7 +112,6 @@ def gptq_runner(tg: Graph,
                     tb_w,
                     tg,
                     tg_bias,
-                    fw_info,
                     fw_impl,
                     hessian_info_service=hessian_info_service)