mct-nightly 2.3.0.20250515.544__py3-none-any.whl → 2.3.0.20250517.552__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: mct-nightly
- Version: 2.3.0.20250515.544
+ Version: 2.3.0.20250517.552
  Summary: A Model Compression Toolkit for neural networks
  Author-email: ssi-dnn-dev@sony.com
  Classifier: Programming Language :: Python :: 3
@@ -21,7 +21,7 @@ Requires-Dist: PuLP
  Requires-Dist: matplotlib<3.10.0
  Requires-Dist: scipy
  Requires-Dist: protobuf
- Requires-Dist: mct-quantizers-nightly
+ Requires-Dist: mct-quantizers==1.6.0
  Requires-Dist: pydantic>=2.0
  Requires-Dist: edge-mdt-cl-dev
  Dynamic: author-email
@@ -1,5 +1,5 @@
- mct_nightly-2.3.0.20250515.544.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
- model_compression_toolkit/__init__.py,sha256=ZuiC7LBUZRbxQhR-vJI5NKeCIc9cX-tIpkHCw_Ynb0o,1557
+ mct_nightly-2.3.0.20250517.552.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+ model_compression_toolkit/__init__.py,sha256=RueNJeNFQG6WxXCYDucXbXAnF5xB1DP5nCV-ouC3da0,1557
  model_compression_toolkit/constants.py,sha256=KNgiNLpsMgSYyXMNEbHXd4bFNerQc1D6HH3vpbUq_Gs,4086
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
  model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -233,7 +233,7 @@ model_compression_toolkit/core/pytorch/back2framework/factory_model_builder.py,s
  model_compression_toolkit/core/pytorch/back2framework/float_model_builder.py,sha256=tLrlUyYhxVKVjkad1ZAtbRra0HedB3iVfIkZ_dYnQ-4,3419
  model_compression_toolkit/core/pytorch/back2framework/instance_builder.py,sha256=BBHBfTqeWm7L3iDyPBpk0jxvj-rBg1QWI23imkjfIl0,1467
  model_compression_toolkit/core/pytorch/back2framework/mixed_precision_model_builder.py,sha256=HAzzWOnPcIeDxQO1712254RNTBZD-gVSMSVnxqpfuQ0,11907
- model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=Z-ZQV-GWdOBGPbksiWBQ8MtFkQ41qgUKU5d5c8aNSjQ,21646
+ model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py,sha256=FVewVclo3kx-Oufr_PJE4-MAqkKJseBvd96vz8JtuBg,22163
  model_compression_toolkit/core/pytorch/back2framework/quantized_model_builder.py,sha256=qZNNOlNTTV4ZKPG3q5GDXkIVTPUEr8dvxAS_YiMORmg,3456
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
  model_compression_toolkit/core/pytorch/back2framework/quantization_wrapper/quantized_layer_wrapper.py,sha256=q2JDw10NKng50ee2i9faGzWZ-IydnR2aOMGSn9RoZmc,5773
@@ -349,7 +349,7 @@ model_compression_toolkit/exporter/model_wrapper/keras/builder/node_to_quantizer
  model_compression_toolkit/exporter/model_wrapper/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
  model_compression_toolkit/exporter/model_wrapper/pytorch/validate_layer.py,sha256=vQUGbCi8_pGoN8DwQ0IblSeN6L9t6Cr0reZNuCbBpkM,3469
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/__init__.py,sha256=cco4TmeIDIh32nj9ZZXVkws4dd9F2UDrmjKzTN8G0V0,697
- model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=gKLKQaVlIx8Rt04aA5EXnG53D1x5N8gaSfUnmip3UK4,6851
+ model_compression_toolkit/exporter/model_wrapper/pytorch/builder/fully_quantized_model_builder.py,sha256=AKSpWbTtXHPjW7hY655OXANaK5SgEiF-FZCu5zoioxM,6860
  model_compression_toolkit/exporter/model_wrapper/pytorch/builder/node_to_quantizer.py,sha256=Pl8a8MSZMzNbm5vngujFjCt_iSMbSmKjlcL1DvN9nTM,9292
  model_compression_toolkit/gptq/__init__.py,sha256=pEgkJvmf05KSw70iLDTz_6LI_2Oi5L8sTN0JsEUpnpk,1445
  model_compression_toolkit/gptq/runner.py,sha256=La12JTYjWyJW0YW4Al4TP1_Xi4JWBCEKw6FR_JQsxe0,5982
@@ -433,20 +433,20 @@ model_compression_toolkit/qat/pytorch/quantizer/ste_rounding/uniform_ste.py,sha2
  model_compression_toolkit/target_platform_capabilities/__init__.py,sha256=8RVOriZg-XNjSt53h_4Yum0oRgOe2gp5H45dfG_lZxE,1415
  model_compression_toolkit/target_platform_capabilities/constants.py,sha256=JRz9DoxLRpkqvu532TFkIvv0595Bfb9NtU4pRp4urDY,1540
  model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
- model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=nbmlygR-nc3bzwnUDrRamq3a6KFkC4-cCpbUeF7EEmo,4626
+ model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=hFBq-qKUM9qKZGaMmrxsEmurTV_D1kWIXI1rTERZsbk,5241
  model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
- model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
- model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=TtMPbiibV6Hk53nl5Y_ctfpI6mSbd8VVH9fxnv5j9eM,4430
+ model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=hf539WJ3nBGn0RnALXrKmAPnbhJ-VmWmLIa207x8b4M,541
+ model_compression_toolkit/target_platform_capabilities/schema/schema_compatability.py,sha256=OpZ9SH2aTAVTCBfj1m3wcAeouk_q_16yWxCwByXK_M8,6294
  model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
  model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=oWKNQnnz04kmijmdWtRyXgVXbJ6BG_V_bUBz_MfUM94,27116
- model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=hryYeGK0zJ2ffcRpHihudtYpl8kIl1WTAQOEsyerqlM,10813
+ model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=1hrvq4EeLDRe0-wvpHkMLXMYYbETQ_tX-3FAHHsxb18,10880
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=9r_lDvRYtbGLKjnH1yLuP4vxWn0_4xS4AkdDhvBg7Ko,7154
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=NCwuvnByeexLL987h67XhU8vQvCgq63bt0hFSiSSxvE,6400
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=5Uyb5CurpLm4fgOiARKYwy3T-bb0NMmJXIRBgRjMgjo,7301
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=R-kTbJka37u3toun9rRDGGGXYR3Sv4VdirLIn5G1BgQ,6541
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
- model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=1jkj0ZO3t9M0SRpe9ZcSucraSoB4raezIbpcO_lZcP4,10084
+ model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=Y-HZKwoakzY6PAYYj9l-h19yLMqBs0qBHo2YIKIsrN8,10375
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities_component.py,sha256=9Hg6AMCzTdDsKKgivRd61UjxGT5SWvKsc3mIUPPsYDQ,1021
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/layer_filter_params.py,sha256=dIu6k1xvGKLtk_47wq1eKYvrS4lYAknAXTeJfFstW0Y,3878
  model_compression_toolkit/target_platform_capabilities/targetplatform2framework/operations_to_layers.py,sha256=vZ7I2XDr_YDgU8oQt8gKkcuUOJf28DCzCPunPK2h_Xw,6563
@@ -528,7 +528,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
- mct_nightly-2.3.0.20250515.544.dist-info/METADATA,sha256=dV9aRBw1JVkuZDXyGl4aFtA91lLC_NtYTDquO5yA8rY,25136
- mct_nightly-2.3.0.20250515.544.dist-info/WHEEL,sha256=QZxptf4Y1BKFRCEDxD4h2V0mBFQOVFLFEpvxHmIs52A,91
- mct_nightly-2.3.0.20250515.544.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
- mct_nightly-2.3.0.20250515.544.dist-info/RECORD,,
+ mct_nightly-2.3.0.20250517.552.dist-info/METADATA,sha256=kxKFMh-zWtlCfUBFowzu71E5L-8ybwVw0pgy_rCxVYw,25135
+ mct_nightly-2.3.0.20250517.552.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ mct_nightly-2.3.0.20250517.552.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+ mct_nightly-2.3.0.20250517.552.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.6.0)
+ Generator: setuptools (80.7.1)
  Root-Is-Purelib: true
  Tag: py3-none-any
 
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
  from model_compression_toolkit import pruning
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
- __version__ = "2.3.0.20250515.000544"
+ __version__ = "2.3.0.20250517.000552"
@@ -233,6 +233,7 @@ class PytorchModel(torch.nn.Module):
  self.return_float_outputs = return_float_outputs
  self.wrapper = wrapper
  self.get_activation_quantizer_holder = get_activation_quantizer_holder_fn
+ self.insert_preserving_quantizers = graph.fqc.insert_preserving_quantizers
  self.reuse_groups = {}
  self._reused_nodes = []
 
@@ -335,12 +336,17 @@ class PytorchModel(torch.nn.Module):
  activation_quantizer_holder = None
  if self.use_activation_holder_during_model_building:
  if node.is_activation_quantization_enabled():
- activation_quantizer_holder = self.get_activation_quantizer_holder(node, holder_type=PytorchActivationQuantizationHolder)
+ activation_quantizer_holder = self.get_activation_quantizer_holder(node,
+ holder_type=PytorchActivationQuantizationHolder)
 
  elif node.is_quantization_preserving():
  prev_node = self.graph.retrieve_preserved_quantization_node(node)
  if prev_node.is_activation_quantization_enabled():
- activation_quantizer_holder = self.get_activation_quantizer_holder(prev_node, holder_type=PytorchPreservingActivationQuantizationHolder)
+ if self.insert_preserving_quantizers:
+ holder_kwargs = {'quantization_bypass': True}
+ activation_quantizer_holder = self.get_activation_quantizer_holder(prev_node,
+ holder_type=PytorchPreservingActivationQuantizationHolder,
+ **holder_kwargs)
 
  if activation_quantizer_holder is not None:
  activation_quantizer_holder_name = node.name + '_' + ACTIVATION_HOLDER_QUANTIZER
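
The hunk above makes the preserving-quantization holder opt-in: it is only created when the TPC flag `insert_preserving_quantizers` (read from `graph.fqc`) is set, and `quantization_bypass=True` now travels through keyword arguments instead of being hard-coded in the holder factory. A minimal, self-contained sketch of that gating logic; `FakeHolder` and `build_activation_holder` are stand-ins for illustration, not MCT APIs:

```python
# Simplified illustration of the holder-selection logic above; not MCT code.
from typing import Callable, Optional


class FakeHolder:
    """Stand-in for Pytorch(Preserving)ActivationQuantizationHolder."""
    def __init__(self, quantizer, quantization_bypass: bool = False):
        self.quantizer = quantizer
        self.quantization_bypass = quantization_bypass


def build_activation_holder(get_holder: Callable,
                            node_preserves_quantization: bool,
                            insert_preserving_quantizers: bool) -> Optional[FakeHolder]:
    # Regular activation-quantized node: always wrapped with a holder.
    if not node_preserves_quantization:
        return get_holder(holder_type=FakeHolder)
    # Quantization-preserving node: only wrapped when the TPC enables it, and the
    # bypass flag is passed as a keyword argument (holder_kwargs in the diff).
    if insert_preserving_quantizers:
        return get_holder(holder_type=FakeHolder, quantization_bypass=True)
    return None


# A factory mirroring get_activation_quantizer_holder(node, holder_type, **kwargs).
make_holder = lambda holder_type, **kwargs: holder_type(quantizer="int8", **kwargs)

print(build_activation_holder(make_holder, True, True).quantization_bypass)  # True
print(build_activation_holder(make_holder, True, False))                     # None
```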
@@ -65,26 +65,29 @@ if FOUND_TORCH:
  return module
 
 
- def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder, fw_impl) -> Callable:
+ def get_activation_quantizer_holder(node: BaseNode, holder_type: PytorchActivationQuantizationHolder,
+ fw_impl, **kwargs) -> Callable:
  """
  Retrieve a PytorchActivationQuantizationHolder layer to use for activation quantization of a node.
  If the layer is not supposed to be wrapped with an activation quantizer - return None.
+
  Args:
  node: Node to attach a PytorchActivationQuantizationHolder to its output.
  holder_type: The type of the activation quantization holder to use.
  fw_impl: FrameworkImplementation object with a specific framework methods implementation.
+ **kwargs: Key-arguments to be passed to the quantization holder initialization to set specific arguments
+ based on the holder's type.
+
  Returns:
  A PytorchActivationQuantizationHolder module for the node's activation quantization.
  """
  # Holder by definition uses a single quantizer for the activation quantization
- # thus we make sure this is the only possible case (unless it's a node we no activation
+ # thus we make sure this is the only possible case (unless it's a node with no activation
  # quantization, which in this case has an empty list).
  _, activation_quantizers = fw_impl.get_inferable_quantizers(node)
  if len(activation_quantizers) == 1:
- if holder_type == PytorchActivationQuantizationHolder:
- return holder_type(activation_quantizers[0])
- elif holder_type == PytorchPreservingActivationQuantizationHolder:
- return holder_type(activation_quantizers[0], quantization_bypass=True)
+ return holder_type(activation_quantizers[0], **kwargs)
+
  Logger.critical(
  f'PytorchActivationQuantizationHolder supports a single quantizer but {len(activation_quantizers)} quantizers '
  f'were found for node {node}')
@@ -105,9 +108,9 @@ if FOUND_TORCH:
  wrapper=lambda n, m:
  fully_quantized_wrapper(n, m,
  fw_impl=fw_impl),
- get_activation_quantizer_holder_fn=lambda n, holder_type:
+ get_activation_quantizer_holder_fn=lambda n, holder_type, **kwargs:
  get_activation_quantizer_holder(n, holder_type,
- fw_impl=fw_impl)).build_model()
+ fw_impl=fw_impl, **kwargs)).build_model()
 
  Logger.info("\nPlease run your accuracy evaluation on the exported quantized model to verify it's accuracy.\n"
  "Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:\n"
@@ -1,4 +1,4 @@
- import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
+ import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema
 
  OperatorSetNames = schema.OperatorSetNames
  Signedness = schema.Signedness
@@ -12,14 +12,15 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import copy
  from typing import Any, Union
 
  import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema_v1
  import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema_v2
  import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as current_schema
 
- ALL_SCHEMA_VERSIONS = [schema_v1] # needs to be updated with all active schema versions
- FUTURE_SCHEMA_VERSIONS = [schema_v2] # once future schema becomes current schema, move to it ALL_SCHEMA_VERSIONS
+ ALL_SCHEMA_VERSIONS = [schema_v1, schema_v2] # needs to be updated with all active schema versions
+ FUTURE_SCHEMA_VERSIONS = [] # once future schema becomes current schema, move to it ALL_SCHEMA_VERSIONS
  all_tpc_types = tuple([s.TargetPlatformCapabilities for s in ALL_SCHEMA_VERSIONS])
  tpc_or_str_type = all_tpc_types + (str,)
 
@@ -33,19 +34,57 @@ def is_tpc_instance(tpc_obj_or_path: Any) -> bool:
  return type(tpc_obj_or_path) in all_tpc_types
 
 
- def _schema_v1_to_v2(tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.TargetPlatformCapabilities: # pragma: no cover
+ def get_schema_by_version(schema_version: str):
+ return {
+ "1": schema_v1,
+ "2": schema_v2
+ }[schema_version]
+
+
+ def _schema_v1_to_v2(
+ tpc: schema_v1.TargetPlatformCapabilities) -> schema_v2.TargetPlatformCapabilities:
  """
- Converts given tpc of schema version 1 to schema version 2
- :return: TargetPlatformCapabilities instance of of schema version 2
+ Converts given tpc of schema version 1 to schema version 2.
+ Schema v2 updates:
+ 1. New TPC field: insert_preserving_quantizers
+ Compatability behavior: Set field to False by default
+ 2. New field in Fusing: fuse_op_quantization_config
+ Compatability behavior: set enable_activation_quantization=False in every fuse_op_quantization_config
+ 3. New operator set names: EXP, SIN, COS
+ Compatability behavior: Not required
+ :return: TargetPlatformCapabilities instance of schema version 2
  """
- raise NotImplementedError("Once schema v2 is implemented, add necessary adaptations to _schema_v1_to_v2 function and remove 'pragma: no cover'")
- return schema_v2.TargetPlatformCapabilities(default_qco=tpc.default_qco,
- operator_set=tpc.operator_set,
- fusing_patterns=tpc.fusing_patterns,
- tpc_minor_version=tpc.tpc_minor_version,
- tpc_patch_version=tpc.tpc_patch_version,
- tpc_platform_type=tpc.tpc_platform_type,
- add_metadata=tpc.add_metadata)
+ v1_default_qco = tpc.default_qco.base_config
+ v2_default_qco = schema_v2.OpQuantizationConfig(
+ default_weight_attr_config=v1_default_qco.default_weight_attr_config,
+ attr_weights_configs_mapping=v1_default_qco.attr_weights_configs_mapping,
+ activation_quantization_method=v1_default_qco.activation_quantization_method,
+ activation_n_bits=v1_default_qco.activation_n_bits,
+ supported_input_activation_n_bits=v1_default_qco.supported_input_activation_n_bits,
+ enable_activation_quantization=False, # set to False by default because feature not exist in schema v1
+ quantization_preserving=v1_default_qco.quantization_preserving,
+ fixed_scale=v1_default_qco.fixed_scale,
+ fixed_zero_point=v1_default_qco.fixed_zero_point,
+ simd_size=v1_default_qco.simd_size,
+ signedness=v1_default_qco.signedness)
+
+ schema_v2_fusing_patters = []
+ for fussing_pattern in tpc.fusing_patterns:
+ schema_v2_fusing_patters.append(
+ schema_v2.Fusing(operator_groups=fussing_pattern.operator_groups,
+ fuse_op_quantization_config=copy.deepcopy(v2_default_qco),
+ name=fussing_pattern.name))
+
+ tpc_schema_v2 = schema_v2.TargetPlatformCapabilities(default_qco=tpc.default_qco,
+ operator_set=tpc.operator_set,
+ fusing_patterns=schema_v2_fusing_patters,
+ tpc_minor_version=tpc.tpc_minor_version,
+ tpc_patch_version=tpc.tpc_patch_version,
+ tpc_platform_type=tpc.tpc_platform_type,
+ add_metadata=tpc.add_metadata,
+ insert_preserving_quantizers=False) # set to False by default because feature not exist in schema v1
+ return tpc_schema_v2
+
 
  def get_conversion_map() -> dict:
  """
@@ -60,7 +99,8 @@ def get_conversion_map() -> dict:
  return conversion_map
 
 
- def tpc_to_current_schema_version(tpc: Union[all_tpc_types]) -> current_schema.TargetPlatformCapabilities: # pragma: no cover
+ def tpc_to_current_schema_version(
+ tpc: Union[all_tpc_types]) -> current_schema.TargetPlatformCapabilities:
  """
  Given tpc instance of some schema version, convert it to the current MCT schema version.
 
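
The schema_compatability changes activate schema v2: `get_schema_by_version` maps a version string to a schema module, `_schema_v1_to_v2` now performs a real conversion (filling the new `insert_preserving_quantizers` and `fuse_op_quantization_config` fields with conservative defaults), and `tpc_to_current_schema_version` applies conversions until the current schema type is reached. A simplified stand-in showing how such a conversion chain behaves; the dataclasses and names below are illustrative, not the MCT types:

```python
# Illustrative stand-in for the schema-upgrade chain; real MCT schema types are richer.
from dataclasses import dataclass, field
from typing import List


@dataclass
class TpcV1:
    fusing_patterns: List[str] = field(default_factory=list)


@dataclass
class TpcV2:
    fusing_patterns: List[str] = field(default_factory=list)
    insert_preserving_quantizers: bool = False  # new in v2; off by default for converted v1 TPCs


def v1_to_v2(tpc: TpcV1) -> TpcV2:
    # Mirrors _schema_v1_to_v2: carry data over, fill new fields with safe defaults.
    return TpcV2(fusing_patterns=list(tpc.fusing_patterns),
                 insert_preserving_quantizers=False)


# Conversion map keyed by source type; conversions are applied until the current type
# is reached, analogous to get_conversion_map() / tpc_to_current_schema_version().
CONVERSION_MAP = {TpcV1: v1_to_v2}
CURRENT_SCHEMA = TpcV2


def to_current_schema(tpc):
    while type(tpc) is not CURRENT_SCHEMA:
        tpc = CONVERSION_MAP[type(tpc)](tpc)
    return tpc


print(to_current_schema(TpcV1(fusing_patterns=["conv+relu"])))
# TpcV2(fusing_patterns=['conv+relu'], insert_preserving_quantizers=False)
```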
@@ -91,6 +91,8 @@ class OperatorSetNames(str, Enum):
  STRIDED_SLICE = "StridedSlice"
  SSD_POST_PROCESS = "SSDPostProcess"
  EXP = "Exp"
+ SIN = "Sin"
+ COS = "Cos"
 
  @classmethod
  def get_values(cls):
@@ -218,11 +220,11 @@ class TargetPlatformCapabilities(BaseModel):
  SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
  """
  default_qco: QuantizationConfigOptions
- operator_set: Optional[Tuple[OperatorsSet, ...]]
- fusing_patterns: Optional[Tuple[Fusing, ...]]
- tpc_minor_version: Optional[int]
- tpc_patch_version: Optional[int]
- tpc_platform_type: Optional[str]
+ operator_set: Optional[Tuple[OperatorsSet, ...]] = None
+ fusing_patterns: Optional[Tuple[Fusing, ...]] = None
+ tpc_minor_version: Optional[int] = None
+ tpc_patch_version: Optional[int] = None
+ tpc_platform_type: Optional[str] = None
  add_metadata: bool = True
  name: Optional[str] = "default_tpc"
 
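
For context on the `= None` defaults added above: with pydantic >= 2 (which this package requires per its METADATA), a field annotated `Optional[...]` but given no default is still a required field, so payloads that omit it fail validation. A minimal check of that behavior, assuming pydantic v2 is installed; the model names are illustrative:

```python
# Behavior difference behind the added "= None" defaults (requires pydantic >= 2).
from typing import Optional

from pydantic import BaseModel, ValidationError


class StrictTPC(BaseModel):
    tpc_minor_version: Optional[int]          # required; None only if passed explicitly


class RelaxedTPC(BaseModel):
    tpc_minor_version: Optional[int] = None   # truly optional; omitted -> None


try:
    StrictTPC()  # field omitted -> validation error
except ValidationError as e:
    print("StrictTPC rejects omission:", len(e.errors()), "error")

print(RelaxedTPC())  # tpc_minor_version=None
```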
@@ -100,7 +100,10 @@ class AttachTpcToKeras(AttachTpcToFramework):
  OperatorSetNames.LOG_SOFTMAX: [tf.nn.log_softmax],
  OperatorSetNames.ADD_BIAS: [tf.nn.bias_add],
  OperatorSetNames.L2NORM: [tf.math.l2_normalize],
- OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess]
+ OperatorSetNames.SSD_POST_PROCESS: [SSDPostProcess],
+ OperatorSetNames.EXP: [tf.math.exp],
+ OperatorSetNames.SIN: [tf.math.sin],
+ OperatorSetNames.COS: [tf.math.cos]
  }
 
  self._opset2attr_mapping = {
@@ -99,6 +99,9 @@ class AttachTpcToPytorch(AttachTpcToFramework):
  Eq('p', 2) | Eq('p', None))],
  OperatorSetNames.SSD_POST_PROCESS: [], # no such operator in pytorch
  OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [MulticlassNMS, MulticlassNMSWithIndices],
+ OperatorSetNames.EXP: [torch.exp],
+ OperatorSetNames.SIN: [torch.sin],
+ OperatorSetNames.COS: [torch.cos],
  }
 
  pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
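
The two attach2* hunks register the new `EXP`, `SIN` and `COS` operator sets against the corresponding Keras and PyTorch callables. A much-simplified illustration of how such an opset-to-callable mapping is consumed; plain `math` functions stand in for the framework ops, and this is not the MCT attach mechanism:

```python
# Simplified picture of an opset-name -> callable mapping; illustration only.
import math

OPSET_TO_CALLABLES = {
    "Exp": [math.exp],
    "Sin": [math.sin],
    "Cos": [math.cos],
}


def callables_for(opset_name: str):
    # An empty list (like SSDPostProcess on PyTorch above) means "no matching layer".
    return OPSET_TO_CALLABLES.get(opset_name, [])


print([fn(0.0) for fn in callables_for("Cos")])  # [1.0]
print(callables_for("SSDPostProcess"))           # []
```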
@@ -239,7 +239,17 @@ class FrameworkQuantizationCapabilities(ImmutableClass):
  def is_simd_padding(self) -> bool:
  """
 
- Returns: Check if the TP model defines that padding due to SIMD constrains occurs.
+ Returns: Check if the TPC defines that padding due to SIMD constrains occurs.
 
  """
  return self.tpc.is_simd_padding
+
+ @property
+ def insert_preserving_quantizers(self) -> bool:
+ """
+
+ Returns: Check if the TPC defines that a quantizer for quantization preserving operators should be added to the
+ constructed model.
+
+ """
+ return self.tpc.insert_preserving_quantizers
@@ -12,12 +12,23 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ import json
  from pathlib import Path
  from typing import Union
 
  import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
  from model_compression_toolkit.target_platform_capabilities.schema.schema_compatability import is_tpc_instance, \
- tpc_to_current_schema_version, tpc_or_str_type
+ tpc_to_current_schema_version, tpc_or_str_type, get_schema_by_version
+
+
+ def _get_json_schema(tpc_json_path: str):
+ """
+ Given a TPC json file path, extract the schema version from it, and return schema object matched to that
+ schema version.
+ """
+ with open(tpc_json_path, 'r', encoding='utf-8') as f:
+ schema_version = str(json.load(f)["SCHEMA_VERSION"])
+ return get_schema_by_version(schema_version)
 
 
  def _get_tpc_from_json(tpc_path: str) -> schema.TargetPlatformCapabilities:
@@ -40,7 +51,11 @@ def _get_tpc_from_json(tpc_path: str) -> schema.TargetPlatformCapabilities:
  raise ValueError(f"Error reading the file '{tpc_path}': {e.strerror}.") from e
 
  try:
- return schema.TargetPlatformCapabilities.parse_raw(data)
+ # json_schema = _get_json_schema(tpc_path)
+ # tpc = json_schema.TargetPlatformCapabilities.parse_raw(data)
+ # return tpc_to_current_schema_version(tpc)
+ tpc = schema.TargetPlatformCapabilities.parse_raw(data)
+ return tpc_to_current_schema_version(tpc)
  except ValueError as e:
  raise ValueError(f"Invalid JSON for loading TargetPlatformCapabilities in '{tpc_path}': {e}.") from e
  except Exception as e:
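
The tpc_io_handler changes add `_get_json_schema`, which peeks at the `SCHEMA_VERSION` field of a TPC JSON to pick the matching schema module; in this build the version-dispatched parse itself is still commented out, and the loader parses with the current schema and then runs `tpc_to_current_schema_version`. A small, self-contained sketch of the version-dispatch idea, with stand-in parsers instead of the pydantic schema classes:

```python
# Version-dispatched loading sketch; stand-in parsers, not MCT code.
import json
import os
import tempfile
from types import SimpleNamespace


def parse_v1(payload: dict):
    return SimpleNamespace(schema_version=1, payload=payload)


def parse_v2(payload: dict):
    return SimpleNamespace(schema_version=2, payload=payload)


# Analogous to get_schema_by_version: route by the version string found in the JSON.
PARSERS = {"1": parse_v1, "2": parse_v2}


def load_tpc(path: str):
    with open(path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    parser = PARSERS[str(payload["SCHEMA_VERSION"])]
    return parser(payload)


# Quick demo with a throwaway JSON file.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"SCHEMA_VERSION": 2, "name": "demo_tpc"}, f)
print(load_tpc(f.name).schema_version)  # 2
os.remove(f.name)
```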