mct-nightly 2.3.0.20250224.520__py3-none-any.whl → 2.3.0.20250226.518__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250226.518.dist-info}/METADATA +4 -4
  2. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250226.518.dist-info}/RECORD +23 -22
  3. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250226.518.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/core/common/collectors/histogram_collector.py +19 -20
  6. model_compression_toolkit/core/common/collectors/statistics_collector.py +7 -3
  7. model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py +114 -0
  8. model_compression_toolkit/core/common/framework_implementation.py +9 -4
  9. model_compression_toolkit/core/common/graph/base_node.py +16 -6
  10. model_compression_toolkit/core/common/hessian/hessian_info_service.py +31 -15
  11. model_compression_toolkit/core/common/hessian/hessian_scores_calculator.py +1 -1
  12. model_compression_toolkit/core/common/hessian/hessian_scores_request.py +7 -2
  13. model_compression_toolkit/core/common/model_collector.py +115 -17
  14. model_compression_toolkit/core/common/quantization/quantization_params_generation/error_functions.py +2 -0
  15. model_compression_toolkit/core/common/quantization/quantization_params_generation/qparams_activations_computation.py +110 -33
  16. model_compression_toolkit/core/keras/keras_implementation.py +35 -27
  17. model_compression_toolkit/core/pytorch/hessian/activation_hessian_scores_calculator_pytorch.py +23 -61
  18. model_compression_toolkit/core/pytorch/pytorch_implementation.py +34 -18
  19. model_compression_toolkit/core/quantization_prep_runner.py +1 -0
  20. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py +2 -2
  21. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +2 -1
  22. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250226.518.dist-info}/LICENSE.md +0 -0
  23. {mct_nightly-2.3.0.20250224.520.dist-info → mct_nightly-2.3.0.20250226.518.dist-info}/top_level.txt +0 -0
@@ -186,18 +186,32 @@ class PytorchImplementation(FrameworkImplementation):
186
186
 
187
187
  def run_model_inference(self,
188
188
  model: Any,
189
- input_list: List[Any]) -> Tuple[torch.Tensor]:
189
+ input_list: List[Any],
190
+ requires_grad: bool = False) -> Tuple[torch.Tensor]:
190
191
  """
191
- Run the model logic on the given the inputs.
192
+ Runs the given PyTorch model on the provided input data.
192
193
 
194
+ This method converts the input data into PyTorch tensors, sets the `requires_grad`
195
+ flag if necessary, and runs inference using the provided model.
193
196
  Args:
194
- model: Pytorch model.
195
- input_list: List of inputs for the model.
197
+ model: The PyTorch model to be executed.
198
+ input_list: A list of input data for the model.
199
+ requires_grad: If True, enables gradient computation for the input tensors.
196
200
 
197
201
  Returns:
198
- The Pytorch model's output.
202
+ A tuple containing the model's output tensors.
199
203
  """
200
- return model(*to_torch_tensor(input_list))
204
+ # Convert input list elements into PyTorch tensors
205
+ torch_tensor_list = to_torch_tensor(input_list)
206
+
207
+ # If gradients are required, enable tracking and gradient retention for each tensor
208
+ if requires_grad:
209
+ for input_tensor in torch_tensor_list:
210
+ input_tensor.requires_grad_()
211
+ input_tensor.retain_grad()
212
+
213
+ # Run the model with the prepared input tensors
214
+ return model(*torch_tensor_list)
201
215
 
202
216
  def shift_negative_correction(self,
203
217
  graph: Graph,
@@ -492,21 +506,23 @@ class PytorchImplementation(FrameworkImplementation):
492
506
 
493
507
  Returns: The MAC count of the operation
494
508
  """
509
+ kernels = fw_info.get_kernel_op_attributes(node.type)
510
+ if not kernels or kernels[0] is None:
511
+ return 0
495
512
 
496
- output_shape = node.output_shape[0]
497
- kernel_shape = node.get_weights_by_keys(fw_info.get_kernel_op_attributes(node.type)[0]).shape
498
- output_channel_axis, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
513
+ assert len(kernels) == 1
514
+ kernel_shape = node.get_weights_by_keys(kernels[0]).shape
499
515
 
500
516
  if node.is_match_type(Conv2d) or node.is_match_type(ConvTranspose2d):
501
- # (C_out * W_out * H_out) * C_in * (W_kernel * H_kernel)
502
- return np.prod([x for x in output_shape if x is not None]) * \
503
- kernel_shape[input_channel_axis] * \
504
- (kernel_shape[0] * kernel_shape[1])
505
- elif node.is_match_type(Linear):
506
- # IN * OUT
507
- return kernel_shape[0] * kernel_shape[1]
508
- else:
509
- return 0
517
+ h, w = node.get_output_shapes_list()[0][-2:]
518
+ return np.prod(kernel_shape) * h * w
519
+
520
+ if node.is_match_type(Linear):
521
+ # IN * OUT * (all previous dims[:-1])
522
+ _, input_channel_axis = fw_info.kernel_channels_mapping.get(node.type)
523
+ return node.get_total_output_params() * kernel_shape[input_channel_axis]
524
+
525
+ return 0
510
526
 
511
527
  def apply_second_moment_correction(self,
512
528
  quantized_model: Any,
@@ -69,6 +69,7 @@ def quantization_preparation_runner(graph: Graph,
69
69
  mi = ModelCollector(graph,
70
70
  fw_impl,
71
71
  fw_info,
72
+ hessian_info_service,
72
73
  core_config.quantization_config) # Mark points for statistics collection
73
74
 
74
75
  for _data in tqdm(representative_data_gen(), "Statistics Collection"):
@@ -39,9 +39,9 @@ class AttachTpcToFramework:
39
39
 
40
40
  tpc = FrameworkQuantizationCapabilities(tpc_model)
41
41
  custom_opset2layer = custom_opset2layer if custom_opset2layer is not None else {}
42
-
42
+ operator_set = tpc_model.operator_set or ()
43
43
  with tpc:
44
- for opset in tpc_model.operator_set:
44
+ for opset in operator_set:
45
45
  if isinstance(opset, OperatorsSet): # filter out OperatorsSetConcat
46
46
  if opset.name in custom_opset2layer:
47
47
  custom_opset_layers = custom_opset2layer[opset.name]
@@ -52,7 +52,8 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
52
52
  self.op_sets_to_layers = OperationsToLayers() # Init an empty OperationsToLayers
53
53
  self.layer2qco, self.filterlayer2qco = {}, {} # Init empty mappings from layers/LayerFilterParams to QC options
54
54
  # Track the unused opsets for warning purposes.
55
- self.__tpc_opsets_not_used = [s.name for s in tpc.operator_set]
55
+ operator_set = tpc.operator_set or ()
56
+ self.__tpc_opsets_not_used = [s.name for s in operator_set]
56
57
  self.remove_fusing_names_from_not_used_list()
57
58
 
58
59
  def get_layers_by_opset_name(self, opset_name: str) -> List[Any]: