mct-nightly 2.3.0.20250428.605__py3-none-any.whl → 2.3.0.20250429.622__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250428.605
3
+ Version: 2.3.0.20250429.622
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: Apache Software License
@@ -51,7 +51,7 @@ ______________________________________________________________________
51
51
  </p>
52
52
  <p align="center">
53
53
  <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/pytorch-2.2%20%7C%202.3%20%7C%202.4%20%7C%202.5-blue" /></a>
54
- <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-2.12%20%7C%202.13%20%7C%202.14%20%7C%202.15-blue" /></a>
54
+ <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/tensorflow-2.14%20%7C%202.15-blue" /></a>
55
55
  <a href="https://sony.github.io/model_optimization#prerequisites"><img src="https://img.shields.io/badge/python-3.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue" /></a>
56
56
  <a href="https://github.com/sony/model_optimization/releases"><img src="https://img.shields.io/github/v/release/sony/model_optimization" /></a>
57
57
  <a href="https://github.com/sony/model_optimization/blob/main/LICENSE.md"><img src="https://img.shields.io/badge/license-Apache%202.0-blue" /></a>
@@ -171,7 +171,7 @@ Currently, MCT is being tested on various Python, Pytorch and TensorFlow version
171
171
  | Python 3.12 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch22.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch23.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch24.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python312_pytorch25.yml) |
172
172
 
173
173
  | | TensorFlow 2.14 | TensorFlow 2.15 |
174
- |-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
174
+ |-------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
175
175
  | Python 3.9 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python39_keras215.yml) |
176
176
  | Python 3.10 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python310_keras215.yml) |
177
177
  | Python 3.11 | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras214.yml) | [![Run Tests](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml/badge.svg)](https://github.com/sony/model_optimization/actions/workflows/run_tests_python311_keras215.yml) |
@@ -1,5 +1,5 @@
1
- mct_nightly-2.3.0.20250428.605.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=KTddrZxT3r5G_WJ3NWOixbYFMVcLk032Ii0ssUccyic,1557
1
+ mct_nightly-2.3.0.20250429.622.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=PrThtvqTwsbfiKBPdckZFWhbBXaRfAATT21j_WbV8pA,1557
3
3
  model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -66,11 +66,11 @@ model_compression_toolkit/core/common/mixed_precision/configurable_quant_id.py,s
66
66
  model_compression_toolkit/core/common/mixed_precision/configurable_quantizer_utils.py,sha256=7dKMi5S0zQZ16m8NWn1XIuoXsKuZUg64G4-uK8-j1PQ,5177
67
67
  model_compression_toolkit/core/common/mixed_precision/distance_weighting.py,sha256=-x8edUyudu1EAEM66AuXPtgayLpzbxoLNubfEbFM5kU,2867
68
68
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_candidates_filter.py,sha256=6pLUEEIqRTVIlCYQC4JIvY55KAvuBHEX8uTOQ-1Ac4Q,3859
69
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=r1t025_QHshyoop-PZvL7x6UuXaeplCCU3h4VNBhJHo,4309
69
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_quantization_config.py,sha256=onHgDwfw8CUbZFNU-RYit9eqA6FrzAtFA3akVZ2d7IM,4533
70
70
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_ru_helper.py,sha256=-hOMBucYn12ePyLd0b1KxniPOIRu4b53SwEzv0bWToI,4943
71
71
  model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_facade.py,sha256=d5-3j2e_rdcQOT7c4s0p7640i3nSetjJ6MgMhhMM7dc,6152
72
- model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=658DBP0sY6DRqEbFcK1gX4EGQMeaBSFE5-7_Py6sioE,37718
73
- model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=4bkM8pYKvk18cxHbx973Dz6qWrNT0MRm44cuk__qVaI,27297
72
+ model_compression_toolkit/core/common/mixed_precision/mixed_precision_search_manager.py,sha256=J8io_axti6gRoch9QR0FmKOP8JSHGeKqX95rf-nG6fI,37719
73
+ model_compression_toolkit/core/common/mixed_precision/sensitivity_evaluation.py,sha256=R3UIO9lKf-lpEGfJOqgpQAXdP1IWMatWxXKYDkhWj_E,28096
74
74
  model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,sha256=P8QtKgFXtt5b2RoubzI5OGlCfbEfZsAirjyrkFzK26A,2846
75
75
  model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=S1ChgxtUjzXJufNWyRbKoNdyNC6fGUjPeComDMx8ZCo,9479
76
76
  model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250428.605.dist-info/METADATA,sha256=qHKhtkD9E5Npa0vcNQc376dwsvBE6iUM0aiTV1S76qg,25560
532
- mct_nightly-2.3.0.20250428.605.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
533
- mct_nightly-2.3.0.20250428.605.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250428.605.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250429.622.dist-info/METADATA,sha256=R9WTA_IVw4cvLIbQbGpmzTl_ujwHc4RNElUajWHnSNE,25101
532
+ mct_nightly-2.3.0.20250429.622.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
533
+ mct_nightly-2.3.0.20250429.622.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250429.622.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250428.000605"
30
+ __version__ = "2.3.0.20250429.000622"
@@ -27,6 +27,7 @@ class MixedPrecisionQuantizationConfig:
27
27
  Args:
28
28
  compute_distance_fn (Callable): Function to compute a distance between two tensors. If None, using pre-defined distance methods based on the layer type for each layer.
29
29
  distance_weighting_method (MpDistanceWeighting): MpDistanceWeighting enum value that provides a function to use when weighting the distances among different layers when computing the sensitivity metric.
30
+ custom_metric_fn (Callable): Function to compute a custom metric. It receives the mixed-precision model (model_mp) as input and returns a float metric value. If None, the interest-points-based distance metric is used.
30
31
  num_of_images (int): Number of images to use to evaluate the sensitivity of a mixed-precision model comparing to the float model.
31
32
  configuration_overwrite (List[int]): A list of integers that enables overwrite of mixed precision with a predefined one.
32
33
  num_interest_points_factor (float): A multiplication factor between zero and one (represents percentage) to reduce the number of interest points used to calculate the distance metric.
@@ -39,6 +40,7 @@ class MixedPrecisionQuantizationConfig:
39
40
 
40
41
  compute_distance_fn: Optional[Callable] = None
41
42
  distance_weighting_method: MpDistanceWeighting = MpDistanceWeighting.AVG
43
+ custom_metric_fn: Optional[Callable] = None
42
44
  num_of_images: int = MP_DEFAULT_NUM_SAMPLES
43
45
  configuration_overwrite: Optional[List[int]] = None
44
46
  num_interest_points_factor: float = field(default=1.0, metadata={"description": "Should be between 0.0 and 1.0"})
@@ -169,6 +169,7 @@ class MixedPrecisionSearchManager:
169
169
  return self.sensitivity_evaluator.compute_metric(topo_cfg(cfg),
170
170
  node_idx,
171
171
  topo_cfg(baseline_cfg) if baseline_cfg else None)
172
+
172
173
  if self.using_virtual_graph:
173
174
  origin_max_config = self.config_reconstruction_helper.reconstruct_config_from_virtual_graph(
174
175
  self.max_ru_config)
@@ -89,6 +89,9 @@ class SensitivityEvaluation:
89
89
  self.interest_points = get_mp_interest_points(graph,
90
90
  fw_impl.count_node_for_mixed_precision_interest_points,
91
91
  quant_config.num_interest_points_factor)
92
+ # If using a custom metric - skip interest points (only model outputs are used)
93
+ if self.quant_config.custom_metric_fn is not None:
94
+ self.interest_points = []
92
95
 
93
96
  # We use normalized MSE when not running hessian-based. For Hessian-based normalized MSE is not needed
94
97
  # because hessian weights already do normalization.
@@ -96,6 +99,9 @@ class SensitivityEvaluation:
96
99
  self.ips_distance_fns, self.ips_axis = self._init_metric_points_lists(self.interest_points, use_normalized_mse)
97
100
 
98
101
  self.output_points = get_output_nodes_for_metric(graph)
102
+ # If using a custom metric - use all model outputs as output points
103
+ if self.quant_config.custom_metric_fn is not None:
104
+ self.output_points = [n.node for n in graph.get_outputs()]
99
105
  self.out_ps_distance_fns, self.out_ps_axis = self._init_metric_points_lists(self.output_points,
100
106
  use_normalized_mse)
101
107
 
@@ -160,7 +166,7 @@ class SensitivityEvaluation:
160
166
  """
161
167
  Compute the sensitivity metric of the MP model for a given configuration (the sensitivity
162
168
  is computed based on the similarity of the interest points' outputs between the MP model
163
- and the float model).
169
+ and the float model or a custom metric if given).
164
170
 
165
171
  Args:
166
172
  mp_model_configuration: Bitwidth configuration to use to configure the MP model.
@@ -177,15 +183,21 @@ class SensitivityEvaluation:
177
183
  node_idx)
178
184
 
179
185
  # Compute the distance metric
180
- ipts_distances, out_pts_distances = self._compute_distance()
186
+ if self.quant_config.custom_metric_fn is None:
187
+ ipts_distances, out_pts_distances = self._compute_distance()
188
+ sensitivity_metric = self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
189
+ self.quant_config.distance_weighting_method)
190
+ else:
191
+ sensitivity_metric = self.quant_config.custom_metric_fn(self.model_mp)
192
+ if not isinstance(sensitivity_metric, (float, np.floating)):
193
+ raise TypeError(f'The custom_metric_fn is expected to return float or numpy float, got {type(sensitivity_metric).__name__}')
181
194
 
182
195
  # Configure MP model back to the same configuration as the baseline model if baseline provided
183
196
  if baseline_mp_configuration is not None:
184
197
  self._configure_bitwidths_model(baseline_mp_configuration,
185
198
  node_idx)
186
199
 
187
- return self._compute_mp_distance_measure(ipts_distances, out_pts_distances,
188
- self.quant_config.distance_weighting_method)
200
+ return sensitivity_metric
189
201
 
190
202
  def _init_baseline_tensors_list(self):
191
203
  """