mct-nightly 2.3.0.20250514.602__py3-none-any.whl → 2.3.0.20250516.613__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (17)
  1. {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/METADATA +2 -2
  2. {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/RECORD +17 -17
  3. {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/WHEEL +1 -1
  4. model_compression_toolkit/__init__.py +1 -1
  5. model_compression_toolkit/constants.py +5 -0
  6. model_compression_toolkit/core/common/fusion/fusing_info.py +75 -7
  7. model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +8 -2
  8. model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py +11 -8
  9. model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py +1 -1
  10. model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py +54 -14
  11. model_compression_toolkit/target_platform_capabilities/schema/v2.py +7 -5
  12. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py +4 -1
  13. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py +3 -0
  14. model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py +25 -6
  15. model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py +17 -2
  16. {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/licenses/LICENSE.md +0 -0
  17. {mct_nightly-2.3.0.20250514.602.dist-info → mct_nightly-2.3.0.20250516.613.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mct-nightly
3
- Version: 2.3.0.20250514.602
3
+ Version: 2.3.0.20250516.613
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Author-email: ssi-dnn-dev@sony.com
6
6
  Classifier: Programming Language :: Python :: 3
@@ -21,7 +21,7 @@ Requires-Dist: PuLP
21
21
  Requires-Dist: matplotlib<3.10.0
22
22
  Requires-Dist: scipy
23
23
  Requires-Dist: protobuf
24
- Requires-Dist: mct-quantizers-nightly
24
+ Requires-Dist: mct-quantizers==1.6.0
25
25
  Requires-Dist: pydantic>=2.0
26
26
  Requires-Dist: edge-mdt-cl-dev
27
27
  Dynamic: author-email
@@ -1,6 +1,6 @@
1
- mct_nightly-2.3.0.20250514.602.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
- model_compression_toolkit/__init__.py,sha256=r8-owZy9MZyc6lGIyuRz2eQeNhwA3DhfnJCgu0VSbhI,1557
3
- model_compression_toolkit/constants.py,sha256=iJ6vfTjC2oFIZWt8wvHoxEw5YJi3yl0Hd4q30_8q0Zc,3958
1
+ mct_nightly-2.3.0.20250516.613.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
2
+ model_compression_toolkit/__init__.py,sha256=vpUrhwvqdXrPwyv56B5qlxS71UYcqZVGpzP-u2yJA9E,1557
3
+ model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
4
4
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
5
5
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
6
6
  model_compression_toolkit/metadata.py,sha256=x_Bk4VpzILdsFax6--CZ3X18qUTP28sbF_AhoQW8dNc,4003
@@ -31,7 +31,7 @@ model_compression_toolkit/core/common/collectors/min_max_per_channel_collector.p
31
31
  model_compression_toolkit/core/common/collectors/statistics_collector.py,sha256=psijsQZefwjMDH8SU5E18n65HiGtQilPhKr1hhzZX-I,8268
32
32
  model_compression_toolkit/core/common/collectors/weighted_histogram_collector.py,sha256=zp3dE7YTqWmkD5QWdRhsl9zD8W6Lr96G1Wjw1g2D3T0,4894
33
33
  model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
34
- model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=W8qZejLwbm-lkvNF3GepNL3ypO10vFRxOxbq-o_rt_I,15479
34
+ model_compression_toolkit/core/common/fusion/fusing_info.py,sha256=S7hBbUJxL52Z8uJ9_upLdFyoSEJvgmVX0OmneqDIj-c,18656
35
35
  model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=F0AaAUBpJ9JjHMB5H2LD9pdwTSWJK-Kqm9dQmGHX1Jo,7368
36
36
  model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
37
37
  model_compression_toolkit/core/common/graph/base_graph.py,sha256=BSQpKy0BXoGX0G0bySTo72n2isTqvtpkbRYYa8-hPO4,41435
@@ -233,7 +233,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
233
233
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
234
234
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
235
235
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=HAzzWOnPcIeDxQO1712254RNTBZD-gVSMSVnxqpfuQ0,11907
236
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Z-ZQV-GWdOBGPbksiWBQ8MtFkQ41qgUKU5d5c8aNSjQ,21646
236
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=FVewVclo3kx-Oufr_PJE4-MAqkKJseBvd96vz8JtuBg,22163
237
237
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
238
238
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
239
239
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -349,7 +349,7 @@ model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer
349
349
  model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
350
350
  model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=vQUGbCi8_pGoN8DwQ0IblSeN6L9t6Cr0reZNuCbBpkM,3469
351
351
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
352
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=gKLKQaVlIx8Rt04aA5EXnG53D1x5N8gaSfUnmip3UK4,6851
352
+ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=AKSpWbTtXHPjW7hY655OXANaK5SgEiF-FZCu5zoioxM,6860
353
353
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=Pl8a8MSZMzNbm5vngujFjCt_iSMbSmKjlcL1DvN9nTM,9292
354
354
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
355
355
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
@@ -433,20 +433,20 @@ model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha2
433
433
  model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=8RVOriZg-XNjSt53h_4Yum0oRgOe2gp5H45dfG_lZxE,1415
434
434
  model_compression_toolkit/target_platform_capabilities/constants.py,sha256=JRz9DoxLRpkqvu532TFkIvv0595Bfb9NtU4pRp4urDY,1540
435
435
  model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
436
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=nbmlygR-nc3bzwnUDrRamq3a6KFkC4-cCpbUeF7EEmo,4626
436
+ model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=hFBq-qKUM9qKZGaMmrxsEmurTV_D1kWIXI1rTERZsbk,5241
437
437
  model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
438
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
439
- model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=TtMPbiibV6Hk53nl5Y_ctfpI6mSbd8VVH9fxnv5j9eM,4430
438
+ model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=hf539WJ3nBGn0RnALXrKmAPnbhJ-VmWmLIa207x8b4M,541
439
+ model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=OpZ9SH2aTAVTCBfj1m3wcAeouk_q_16yWxCwByXK_M8,6294
440
440
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
441
441
  model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
442
- model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=hryYeGK0zJ2ffcRpHihudtYpl8kIl1WTAQOEsyerqlM,10813
442
+ model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=1hrvq4EeLDRe0-wvpHkMLXMYYbETQ_tX-3FAHHsxb18,10880
443
443
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
444
444
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
445
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=9r_lDvRYtbGLKjnH1yLuP4vxWn0_4xS4AkdDhvBg7Ko,7154
446
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=NCwuvnByeexLL987h67XhU8vQvCgq63bt0hFSiSSxvE,6400
445
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=5Uyb5CurpLm4fgOiARKYwy3T-bb0NMmJXIRBgRjMgjo,7301
446
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=R-kTbJka37u3toun9rRDGGGXYR3Sv4VdirLIn5G1BgQ,6541
447
447
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
448
448
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
449
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
449
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=Y-HZKwoakzY6PAYYj9l-h19yLMqBs0qBHo2YIKIsrN8,10375
450
450
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities_component.py,sha256=9Hg6AMCzTdDsKKgivRd61UjxGT5SWvKsc3mIUPPsYDQ,1021
451
451
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/layer_filter_params.py,sha256=dIu6k1xvGKLtk_47wq1eKYvrS4lYAknAXTeJfFstW0Y,3878
452
452
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/operations_to_layers.py,sha256=vZ7I2XDr_YDgU8oQt8gKkcuUOJf28DCzCPunPK2h_Xw,6563
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
528
528
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
529
529
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
530
530
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
531
- mct_nightly-2.3.0.20250514.602.dist-info/METADATA,sha256=G4fXoMcNvxB_V_F7hCVDXAJh5Q2bOmcBRS9FzqtQngM,25136
532
- mct_nightly-2.3.0.20250514.602.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
533
- mct_nightly-2.3.0.20250514.602.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
- mct_nightly-2.3.0.20250514.602.dist-info/RECORD,,
531
+ mct_nightly-2.3.0.20250516.613.dist-info/METADATA,sha256=WTFK8E9-__LO8PW9LL44DHCYKwUvNgKMkGl0ko8dcf0,25135
532
+ mct_nightly-2.3.0.20250516.613.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
533
+ mct_nightly-2.3.0.20250516.613.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
534
+ mct_nightly-2.3.0.20250516.613.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.4.0)
2
+ Generator: setuptools (80.7.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.3.0.20250514.000602"
30
+ __version__ = "2.3.0.20250516.000613"
@@ -138,3 +138,8 @@ SHAPE = 'shape'
138
138
  NODE_NAME = 'node_name'
139
139
  TOTAL_SIZE = 'total_size'
140
140
  NODE_OUTPUT_INDEX = 'node_output_index'
141
+
142
+
143
+ # Fusing Patterns constants
144
+ FUSED_LAYER_PATTERN = 'fused_layer_pattern'
145
+ FUSED_OP_QUANT_CONFIG = 'fused_op_quantization_config'
@@ -14,6 +14,8 @@
14
14
  # ==============================================================================
15
15
 
16
16
  from model_compression_toolkit.target_platform_capabilities import LayerFilterParams
17
+ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig
18
+ from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
17
19
  from dataclasses import dataclass, field
18
20
 
19
21
  from typing import Optional, List, Dict, Any, Tuple
@@ -41,6 +43,7 @@ class FusingInfo:
41
43
  fusing_patterns: any = None
42
44
  fusing_data: Dict[str, Tuple['BaseNode']] = field(default_factory=dict)
43
45
  node_to_fused_node_map: Dict[str, str] = field(init=False, default_factory=dict)
46
+ fused_op_id_to_quant_config: Dict[str, OpQuantizationConfig] = field(default_factory=dict)
44
47
 
45
48
  def __post_init__(self):
46
49
  """Validates and initializes mappings after dataclass instantiation."""
@@ -49,6 +52,7 @@ class FusingInfo:
49
52
  assert isinstance(op_nodes, tuple) and len(op_nodes) > 1, f"Found invalid fused op nodes: {op_nodes}"
50
53
 
51
54
  self._init_node_mapping()
55
+ self._init_quantization_config_map()
52
56
 
53
57
  def _init_node_mapping(self) -> None:
54
58
  """
@@ -59,6 +63,15 @@ class FusingInfo:
59
63
  for node in nodes:
60
64
  self.node_to_fused_node_map[node.name] = op_id
61
65
 
66
+ def _init_quantization_config_map(self) -> None:
67
+ """
68
+ Init the mapping between fused operation IDs and their quantization configurations.
69
+ """
70
+ self.fused_op_id_to_quant_config.clear()
71
+ if self.fusing_patterns is not None:
72
+ for op_id, nodes in self.fusing_data.items():
73
+ self.set_fused_op_quantization_config(op_id, nodes)
74
+
62
75
  def add_fused_operation(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
63
76
  """
64
77
  Add a new fused operation with the given ID and set of nodes.
@@ -78,6 +91,22 @@ class FusingInfo:
78
91
  for node in nodes:
79
92
  self.node_to_fused_node_map[node.name] = op_id
80
93
 
94
+ # Update the quantization config mapping for this operation
95
+ if self.fusing_patterns is not None:
96
+ self.set_fused_op_quantization_config(op_id, nodes)
97
+
98
+ def set_fused_op_quantization_config(self, op_id: str, nodes: Tuple['BaseNode']) -> None:
99
+ """
100
+ Set the quantization configuration for a given fused operation ID.
101
+
102
+ Args:
103
+ op_id (str): The identifier for the fused operation.
104
+ nodes (Tuple[BaseNode]): The tuple of nodes that form the fused operation.
105
+ """
106
+ fusing_pattern = next((fp for fp in self.fusing_patterns if is_valid_fusion([fp.get(FUSED_LAYER_PATTERN)], nodes)), None)
107
+ if fusing_pattern is not None:
108
+ self.fused_op_id_to_quant_config[op_id] = fusing_pattern.get(FUSED_OP_QUANT_CONFIG)
109
+
81
110
  def remove_fused_operation(self, op_id: str) -> None:
82
111
  """
83
112
  Remove a fused operation by its ID.
@@ -95,6 +124,7 @@ class FusingInfo:
95
124
  for node in nodes:
96
125
  self.node_to_fused_node_map.pop(node.name, None)
97
126
  del self.fusing_data[op_id]
127
+ self.fused_op_id_to_quant_config.pop(op_id, None)
98
128
 
99
129
  def get_fused_node_name(self, node_name: str) -> Optional[str]:
100
130
  """
@@ -117,6 +147,15 @@ class FusingInfo:
117
147
  """
118
148
  return self.node_to_fused_node_map.copy()
119
149
 
150
+ def get_fusing_quantization_config_map(self) -> Dict[str, OpQuantizationConfig]:
151
+ """
152
+ Retrieve a copy of the mapping from fused operation IDs to their quantization configurations.
153
+
154
+ Returns:
155
+ A dictionary mapping each fused operation ID to its quantization configuration.
156
+ """
157
+ return self.fused_op_id_to_quant_config.copy()
158
+
120
159
  def get_fused_nodes(self, op_id: str) -> Optional[List['BaseNode']]:
121
160
  """
122
161
  Retrieve the list of nodes for a given fused operation ID.
@@ -129,6 +168,18 @@ class FusingInfo:
129
168
  """
130
169
  return self.fusing_data.get(op_id)
131
170
 
171
+ def get_fused_op_quantization_config(self, op_id: str) -> OpQuantizationConfig:
172
+ """
173
+ Retrieve the quantization configuration for a given fused operation ID.
174
+
175
+ Args:
176
+ op_id (str): The identifier for the fused operation.
177
+
178
+ Returns:
179
+ OpQuantizationConfig: The quantization configuration for the operation, or None if not found.
180
+ """
181
+ return self.fused_op_id_to_quant_config.get(op_id)
182
+
132
183
  def is_node_in_fused_op(self, node: 'BaseNode') -> bool:
133
184
  """
134
185
  Check if a node is part of any fused operation.
@@ -216,10 +267,11 @@ class FusingInfo:
216
267
  all_fused_nodes.update(node_set)
217
268
 
218
269
  # Check 4: Ensure the sequence matches a valid fusing pattern
219
- if not is_valid_fusion(self.fusing_patterns, nodes):
270
+ valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
271
+ if not is_valid_fusion(valid_fusing_patterns, nodes):
220
272
  raise ValueError(
221
273
  f"Fused operation {op_id} does not match any valid fusing pattern "
222
- f"from {self.fusing_patterns}."
274
+ f"from {valid_fusing_patterns}."
223
275
  )
224
276
 
225
277
  def is_nodes_eligible_to_be_fused(self, nodes: List['BaseNode']) -> bool:
@@ -240,7 +292,8 @@ class FusingInfo:
240
292
  return False
241
293
 
242
294
  # Check if the provided nodes match a valid fusion pattern
243
- return is_valid_fusion(fusing_patterns=self.fusing_patterns, nodes=nodes)
295
+ valid_fusing_patterns = _get_fusing_layer_patterns(self.fusing_patterns)
296
+ return is_valid_fusion(fusing_patterns=valid_fusing_patterns, nodes=nodes)
244
297
 
245
298
  def __repr__(self) -> str:
246
299
  """
@@ -287,8 +340,11 @@ class FusingInfoGenerator:
287
340
  if not self._fusing_patterns:
288
341
  return FusingInfo(fusing_patterns=self._fusing_patterns)
289
342
 
343
+ # Extract fusing layer patterns
344
+ fusing_layer_patterns = _get_fusing_layer_patterns(self._fusing_patterns)
345
+
290
346
  # Find max fusion
291
- max_layers_fusing = max([len(fusing_pattern) for fusing_pattern in self._fusing_patterns])
347
+ max_layer_patterns = max([len(fusing_layer_pattern) for fusing_layer_pattern in fusing_layer_patterns])
292
348
 
293
349
  # Travel along the graph to find layers for fusing
294
350
  nodes = graph.get_topo_sorted_nodes()
@@ -302,9 +358,9 @@ class FusingInfoGenerator:
302
358
  continue
303
359
  # Start fusing search
304
360
  fusing_nodes = [] # nodes that are candidates for participating in fusing
305
- patterns = copy.deepcopy(self._fusing_patterns)
361
+ patterns = copy.deepcopy(fusing_layer_patterns)
306
362
  next_nodes = [node]
307
- for i in range(max_layers_fusing):
363
+ for i in range(max_layer_patterns):
308
364
  patterns = get_valid_fusing_patterns_for_node(patterns, next_nodes[0], i)
309
365
  if len(patterns) == 0: # Give up if no more fusion pattern
310
366
  break
@@ -314,7 +370,7 @@ class FusingInfoGenerator:
314
370
  break
315
371
 
316
372
  # New fusion
317
- if is_valid_fusion(self._fusing_patterns, fusing_nodes):
373
+ if is_valid_fusion(fusing_layer_patterns, fusing_nodes):
318
374
  fused_op_id = FusingInfo.generate_fused_op_id(fusing_nodes)
319
375
  assert fused_op_id not in fusing_info, f"{fused_op_id} is already in fusing info: {fusing_info}"
320
376
  fusing_info[fused_op_id] = tuple(fusing_nodes)
@@ -371,3 +427,15 @@ def is_valid_fusion(fusing_patterns: List[List[Any]], nodes: List['BaseNode']) -
371
427
  if counter == fusion_depth:
372
428
  return True
373
429
  return False
430
+
431
+
432
+ def _get_fusing_layer_patterns(fusing_patterns: List[Dict[Any, OpQuantizationConfig]]) -> List[List[Any]]:
433
+ """
434
+ Extracts the fusing layer patterns from the provided fusing patterns.
435
+ Args:
436
+ fusing_patterns: List of fusing patterns, each holding the layers/LayerFilterParams to fuse and the quantization config mapped to them.
437
+
438
+ Returns:
439
+ supported fusing layer patterns
440
+ """
441
+ return [f.get(FUSED_LAYER_PATTERN) for f in fusing_patterns]
@@ -233,6 +233,7 @@ class PytorchModel(torch.nn.Module):
233
233
  self.return_float_outputs = return_float_outputs
234
234
  self.wrapper = wrapper
235
235
  self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
236
+ self.insert_preserving_quantizers = graph.fqc.insert_preserving_quantizers
236
237
  self.reuse_groups = {}
237
238
  self._reused_nodes = []
238
239
 
@@ -335,12 +336,17 @@ class PytorchModel(torch.nn.Module):
335
336
  activation_quantizer_holder = None
336
337
  if self.use_activation_holder_during_model_building:
337
338
  if node.is_activation_quantization_enabled():
338
- activation_quantizer_holder = self.get_activation_quantizer_holder(node, holder_type=PytorchActivationQuantizationHolder)
339
+ activation_quantizer_holder = self.get_activation_quantizer_holder(node,
340
+ holder_type=PytorchActivationQuantizationHolder)
339
341
 
340
342
  elif node.is_quantization_preserving():
341
343
  prev_node = self.graph.retrieve_preserved_quantization_node(node)
342
344
  if prev_node.is_activation_quantization_enabled():
343
- activation_quantizer_holder = self.get_activation_quantizer_holder(prev_node, holder_type=PytorchPreservingActivationQuantizationHolder)
345
+ if self.insert_preserving_quantizers:
346
+ holder_kwargs = {'quantization_bypass': True}
347
+ activation_quantizer_holder = self.get_activation_quantizer_holder(prev_node,
348
+ holder_type=PytorchPreservingActivationQuantizationHolder,
349
+ **holder_kwargs)
344
350
 
345
351
  if activation_quantizer_holder is not None:
346
352
  activation_quantizer_holder_name = node.name + '_' + ACTIVATION_HOLDER_QUANTIZER
@@ -65,26 +65,29 @@ if FOUND_TORCH:
65
65
  return module
66
66
 
67
67
 
68
- def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder, fw_impl) -> Callable:
68
+ def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder,
69
+ fw_impl, **kwargs) -> Callable:
69
70
  """
70
71
  Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
71
72
  If the layer is not supposed to be wrapped with an activation quantizer - return None.
73
+
72
74
  Args:
73
75
  node: Node to attach a PytorchActivationQuantizationHolder to its output.
74
76
  holder_type: The type of the activation quantization holder to use.
75
77
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
78
+ **kwargs: Keyword arguments to be passed to the quantization holder initialization to set specific arguments
79
+ based on the holder's type.
80
+
76
81
  Returns:
77
82
  A PytorchActivationQuantizationHolder module for the node's activation quantization.
78
83
  """
79
84
  # Holder by definition uses a single quantizer for the activation quantization
80
- # thus we make sure this is the only possible case (unless it's a node we no activation
85
+ # thus we make sure this is the only possible case (unless it's a node with no activation
81
86
  # quantization, which in this case has an empty list).
82
87
  _, activation_quantizers = fw_impl.get_inferable_quantizers(node)
83
88
  if len(activation_quantizers) == 1:
84
- if holder_type == PytorchActivationQuantizationHolder:
85
- return holder_type(activation_quantizers[0])
86
- elif holder_type == PytorchPreservingActivationQuantizationHolder:
87
- return holder_type(activation_quantizers[0], quantization_bypass=True)
89
+ return holder_type(activation_quantizers[0], **kwargs)
90
+
88
91
  Logger.critical(
89
92
  f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
90
93
  f'were found for node {node}')
@@ -105,9 +108,9 @@ if FOUND_TORCH:
105
108
  wrapper=lambda n, m:
106
109
  fully_quantized_wrapper(n, m,
107
110
  fw_impl=fw_impl),
108
- get_activation_quantizer_holder_fn=lambda n, holder_type:
111
+ get_activation_quantizer_holder_fn=lambda n, holder_type, **kwargs:
109
112
  get_activation_quantizer_holder(n, holder_type,
110
- fw_impl=fw_impl)).build_model()
113
+ fw_impl=fw_impl, **kwargs)).build_model()
111
114
 
112
115
  Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
113
116
  "Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
@@ -1,4 +1,4 @@
1
- import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
1
+ import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema
2
2
 
3
3
  OperatorSetNames = schema.OperatorSetNames
4
4
  Signedness = schema.Signedness
@@ -12,14 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import copy
15
16
  from typing import Any, Union
16
17
 
17
18
  import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema_v1
18
19
  import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema_v2
19
20
  import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as current_schema
20
21
 
21
- ALL_SCHEMA_VERSIONS = [schema_v1] # needs to be updated with all active schema versions
22
- FUTURE_SCHEMA_VERSIONS = [schema_v2] # once future schema becomes current schema, move to it ALL_SCHEMA_VERSIONS
22
+ ALL_SCHEMA_VERSIONS = [schema_v1, schema_v2] # needs to be updated with all active schema versions
23
+ FUTURE_SCHEMA_VERSIONS = [] # once future schema becomes current schema, move to it ALL_SCHEMA_VERSIONS
23
24
  all_tpc_types = tuple([s.TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS])
24
25
  tpc_or_str_type = all_tpc_types + (str,)
25
26
 
@@ -33,19 +34,57 @@ def is_tpc_instance(tpc_obj_or_path: Any) -> bool:
33
34
  return type(tpc_obj_or_path) in all_tpc_types
34
35
 
35
36
 
36
- def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.TargetPlatformCapabilities: # pragma: no cover
37
+ def get_schema_by_version(schema_version: str):
38
+ return {
39
+ "1": schema_v1,
40
+ "2": schema_v2
41
+ }[schema_version]
42
+
43
+
44
+ def _schema_v1_to_v2(
45
+ tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.TargetPlatformCapabilities:
37
46
  """
38
- Converts given tpc of schema version 1 to schema version 2
39
- :return: TargetPlatformCapabilities instance of of schema version 2
47
+ Converts given tpc of schema version 1 to schema version 2.
48
+ Schema v2 updates:
49
+ 1. New TPC field: insert_preserving_quantizers
50
+ Compatibility behavior: Set field to False by default
51
+ 2. New field in Fusing: fuse_op_quantization_config
52
+ Compatibility behavior: set enable_activation_quantization=False in every fuse_op_quantization_config
53
+ 3. New operator set names: EXP, SIN, COS
54
+ Compatability behavior: Not required
55
+ :return: TargetPlatformCapabilities instance of schema version 2
40
56
  """
41
- raise NotImplementedError("Once schema v2 is implemented, add necessary adaptations to _schema_v1_to_v2 function and remove 'pragma: no cover'")
42
- return schema_v2.TargetPlatformCapabilities(default_qco=tpc.default_qco,
43
- operator_set=tpc.operator_set,
44
- fusing_patterns=tpc.fusing_patterns,
45
- tpc_minor_version=tpc.tpc_minor_version,
46
- tpc_patch_version=tpc.tpc_patch_version,
47
- tpc_platform_type=tpc.tpc_platform_type,
48
- add_metadata=tpc.add_metadata)
57
+ v1_default_qco = tpc.default_qco.base_config
58
+ v2_default_qco = schema_v2.OpQuantizationConfig(
59
+ default_weight_attr_config=v1_default_qco.default_weight_attr_config,
60
+ attr_weights_configs_mapping=v1_default_qco.attr_weights_configs_mapping,
61
+ activation_quantization_method=v1_default_qco.activation_quantization_method,
62
+ activation_n_bits=v1_default_qco.activation_n_bits,
63
+ supported_input_activation_n_bits=v1_default_qco.supported_input_activation_n_bits,
64
+ enable_activation_quantization=False, # set to False by default because feature not exist in schema v1
65
+ quantization_preserving=v1_default_qco.quantization_preserving,
66
+ fixed_scale=v1_default_qco.fixed_scale,
67
+ fixed_zero_point=v1_default_qco.fixed_zero_point,
68
+ simd_size=v1_default_qco.simd_size,
69
+ signedness=v1_default_qco.signedness)
70
+
71
+ schema_v2_fusing_patters = []
72
+ for fussing_pattern in tpc.fusing_patterns:
73
+ schema_v2_fusing_patters.append(
74
+ schema_v2.Fusing(operator_groups=fussing_pattern.operator_groups,
75
+ fuse_op_quantization_config=copy.deepcopy(v2_default_qco),
76
+ name=fussing_pattern.name))
77
+
78
+ tpc_schema_v2 = schema_v2.TargetPlatformCapabilities(default_qco=tpc.default_qco,
79
+ operator_set=tpc.operator_set,
80
+ fusing_patterns=schema_v2_fusing_patters,
81
+ tpc_minor_version=tpc.tpc_minor_version,
82
+ tpc_patch_version=tpc.tpc_patch_version,
83
+ tpc_platform_type=tpc.tpc_platform_type,
84
+ add_metadata=tpc.add_metadata,
85
+ insert_preserving_quantizers=False) # set to False by default because feature not exist in schema v1
86
+ return tpc_schema_v2
87
+
49
88
 
50
89
  def get_conversion_map() -> dict:
51
90
  """
@@ -60,7 +99,8 @@ def get_conversion_map() -> dict:
60
99
  return conversion_map
61
100
 
62
101
 
63
- def tpc_to_current_schema_version(tpc: Union[all_tpc_types]) -> current_schema.TargetPlatformCapabilities: # pragma: no cover
102
+ def tpc_to_current_schema_version(
103
+ tpc: Union[all_tpc_types]) -> current_schema.TargetPlatformCapabilities:
64
104
  """
65
105
  Given tpc instance of some schema version, convert it to the current MCT schema version.
66
106
 
@@ -91,6 +91,8 @@ class OperatorSetNames(str, Enum):
91
91
  STRIDED_SLICE = "StridedSlice"
92
92
  SSD_POST_PROCESS = "SSDPostProcess"
93
93
  EXP = "Exp"
94
+ SIN = "Sin"
95
+ COS = "Cos"
94
96
 
95
97
  @classmethod
96
98
  def get_values(cls):
@@ -218,11 +220,11 @@ class TargetPlatformCapabilities(BaseModel):
218
220
  SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
219
221
  """
220
222
  default_qco: QuantizationConfigOptions
221
- operator_set: Optional[Tuple[OperatorsSet, ...]]
222
- fusing_patterns: Optional[Tuple[Fusing, ...]]
223
- tpc_minor_version: Optional[int]
224
- tpc_patch_version: Optional[int]
225
- tpc_platform_type: Optional[str]
223
+ operator_set: Optional[Tuple[OperatorsSet, ...]] = None
224
+ fusing_patterns: Optional[Tuple[Fusing, ...]] = None
225
+ tpc_minor_version: Optional[int] = None
226
+ tpc_patch_version: Optional[int] = None
227
+ tpc_platform_type: Optional[str] = None
226
228
  add_metadata: bool = True
227
229
  name: Optional[str] = "default_tpc"
228
230
 
@@ -100,7 +100,10 @@ class AttachTpcToKeras(AttachTpcToFramework):
100
100
  OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
101
101
  OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
102
102
  OperatorSetNames.L2NORM: [tf.math.l2_normalize],
103
- OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess]
103
+ OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess],
104
+ OperatorSetNames.EXP: [tf.math.exp],
105
+ OperatorSetNames.SIN: [tf.math.sin],
106
+ OperatorSetNames.COS: [tf.math.cos]
104
107
  }
105
108
 
106
109
  self._opset2attr_mapping = {
@@ -99,6 +99,9 @@ class AttachTpcToPytorch(AttachTpcToFramework):
99
99
  Eq('p', 2) | Eq('p', None))],
100
100
  OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
101
101
  OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
102
+ OperatorSetNames.EXP: [torch.exp],
103
+ OperatorSetNames.SIN: [torch.sin],
104
+ OperatorSetNames.COS: [torch.cos],
102
105
  }
103
106
 
104
107
  pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
@@ -31,6 +31,9 @@ from model_compression_toolkit.target_platform_capabilities.schema.mct_current_s
31
31
  OpQuantizationConfig, QuantizationConfigOptions
32
32
  from model_compression_toolkit.target_platform_capabilities.targetplatform2framework.current_tpc import _current_tpc
33
33
 
34
+ from model_compression_toolkit.constants import FUSED_LAYER_PATTERN, FUSED_OP_QUANT_CONFIG
35
+
36
+
34
37
  class FrameworkQuantizationCapabilities(ImmutableClass):
35
38
  """
36
39
  Attach framework information to a modeled hardware.
@@ -94,20 +97,26 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
94
97
  """
95
98
  return self.op_sets_to_layers.get_layers_by_op(op)
96
99
 
97
- def get_fusing_patterns(self) -> List[List[Any]]:
100
+ def get_fusing_patterns(self) -> List[Dict[List[Any], OpQuantizationConfig]]:
98
101
  """
99
102
 
100
- Returns: List of patterns of layers/LayerFilterParams to fuse.
103
+ Returns: List of patterns of layers/LayerFilterParams to fuse, each paired with its corresponding quantization config.
101
104
 
102
105
  """
103
- res = []
106
+
107
+ patterns = []
104
108
  if self.tpc.fusing_patterns is None:
105
- return res
109
+ return patterns
110
+
106
111
  for p in self.tpc.fusing_patterns:
112
+ res = []
107
113
  ops = [self.get_layers_by_opset(x) for x in p.operator_groups]
108
114
  res.extend(itertools.product(*ops))
109
- return [list(x) for x in res]
110
115
 
116
+ fused_op_quant_config = getattr(p, FUSED_OP_QUANT_CONFIG, None)
117
+ patterns.extend({FUSED_LAYER_PATTERN: list(x), FUSED_OP_QUANT_CONFIG: fused_op_quant_config} for x in res)
118
+
119
+ return patterns
111
120
 
112
121
  def get_info(self) -> Dict[str, Any]:
113
122
  """
@@ -230,7 +239,17 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
230
239
  def is_simd_padding(self) -> bool:
231
240
  """
232
241
 
233
- Returns: Check if the TP model defines that padding due to SIMD constrains occurs.
242
+ Returns: Check if the TPC defines that padding due to SIMD constraints occurs.
234
243
 
235
244
  """
236
245
  return self.tpc.is_simd_padding
246
+
247
+ @property
248
+ def insert_preserving_quantizers(self) -> bool:
249
+ """
250
+
251
+ Returns: Whether the TPC defines that a quantizer for quantization-preserving operators should be added to the
252
+ constructed model.
253
+
254
+ """
255
+ return self.tpc.insert_preserving_quantizers
@@ -12,12 +12,23 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ import json
15
16
  from pathlib import Path
16
17
  from typing import Union
17
18
 
18
19
  import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
19
20
  from model_compression_toolkit.target_platform_capabilities.schema.schema_compatability import is_tpc_instance, \
20
- tpc_to_current_schema_version, tpc_or_str_type
21
+ tpc_to_current_schema_version, tpc_or_str_type, get_schema_by_version
22
+
23
+
24
+ def _get_json_schema(tpc_json_path: str):
25
+ """
26
+ Given a TPC json file path, extract the schema version from it, and return schema object matched to that
27
+ schema version.
28
+ """
29
+ with open(tpc_json_path, 'r', encoding='utf-8') as f:
30
+ schema_version = str(json.load(f)["SCHEMA_VERSION"])
31
+ return get_schema_by_version(schema_version)
21
32
 
22
33
 
23
34
  def _get_tpc_from_json(tpc_path: str) -> schema.TargetPlatformCapabilities:
@@ -40,7 +51,11 @@ def _get_tpc_from_json(tpc_path: str) -> schema.TargetPlatformCapabilities:
40
51
  raise ValueError(f"Error reading the file '{tpc_path}': {e.strerror}.") from e
41
52
 
42
53
  try:
43
- return schema.TargetPlatformCapabilities.parse_raw(data)
54
+ # json_schema = _get_json_schema(tpc_path)
55
+ # tpc = json_schema.TargetPlatformCapabilities.parse_raw(data)
56
+ # return tpc_to_current_schema_version(tpc)
57
+ tpc = schema.TargetPlatformCapabilities.parse_raw(data)
58
+ return tpc_to_current_schema_version(tpc)
44
59
  except ValueError as e:
45
60
  raise ValueError(f"Invalid JSON for loading TargetPlatformCapabilities in '{tpc_path}': {e}.") from e
46
61
  except Exception as e: