mct-nightly 2.3.0.20250331.610-py3-none-any.whl → 2.3.0.20250402.536-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: mct-nightly
-Version: 2.3.0.20250331.610
+Version: 2.3.0.20250402.536
 Summary: A Model Compression Toolkit for neural networks
 Classifier: Programming Language :: Python :: 3
 Classifier: License :: OSI Approved :: Apache Software License
@@ -1,5 +1,5 @@
-mct_nightly-2.3.0.20250331.610.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
-model_compression_toolkit/__init__.py,sha256=tulTPjSjSee0ySdxccC46EZ_AbQkEIr3RvtUps-6IME,1557
+mct_nightly-2.3.0.20250402.536.dist-info/licenses/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+model_compression_toolkit/__init__.py,sha256=dhPx1u7eKO_zAY9CefOYP31YislX9FLOXxMFWv9PVJo,1557
 model_compression_toolkit/constants.py,sha256=2ltuH-gdaLZoZV4CPUgKjC3S9ojz2z4OTVdenyVEypU,3912
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=L3q7tn3Uht0i_7phnlOWMR2Te2zvzrt2HOz9vYEInts,4529
@@ -34,8 +34,8 @@ model_compression_toolkit/core/common/fusion/__init__.py,sha256=Rf1RcYmelmdZmBV5
 model_compression_toolkit/core/common/fusion/graph_fuser.py,sha256=b41_4rL_Adiza4vpWlmmqgvkpUmWVdfdx0nEIB0p2n8,6195
 model_compression_toolkit/core/common/fusion/layer_fusing.py,sha256=-2fnjyC9q2RPw9st6RxROW-gdtT2mSRz0QZ_Gz1KDz4,5579
 model_compression_toolkit/core/common/graph/__init__.py,sha256=Xr-Lt_qXMdrCnnOaUS_OJP_3iTTGfPCLf8_vSrQgCs0,773
-model_compression_toolkit/core/common/graph/base_graph.py,sha256=VhniLTiMqL7i1Vqg2UBQuFFTvw2cYeJayssUJwabp3E,38112
-model_compression_toolkit/core/common/graph/base_node.py,sha256=kZbmAMh5cPAwYzlY8KYa8w0ipL58yApB09-WXQ8plrE,33763
+model_compression_toolkit/core/common/graph/base_graph.py,sha256=cSwHUqwZEiR1t2DaBfc7_qSJbtX8crpqerN4ol9v3H8,38859
+model_compression_toolkit/core/common/graph/base_node.py,sha256=CJu8_r80MGVnYmlAUGOGKGRsD9xShMyaRNb3VMeRC0s,34523
 model_compression_toolkit/core/common/graph/edge.py,sha256=buoSEUZwilWBK3WeBKpJ-GeDaUA1SDdOHxDpxU_bGpk,3784
 model_compression_toolkit/core/common/graph/functional_node.py,sha256=GH5wStmw8SoAj5IdT_-ItN1Meo_P5NUTt_5bgJC4fak,3935
 model_compression_toolkit/core/common/graph/graph_matchers.py,sha256=CrDoHYq4iPaflgJWmoJ1K4ziLrRogJvFTVWg8P0UcDU,4744
@@ -75,10 +75,10 @@ model_compression_toolkit/core/common/mixed_precision/set_layer_to_bitwidth.py,s
 model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py,sha256=fk7PWiZ6Na5O_Z_dymk_UfDCTqW_X_4EROU7DZknQnc,9444
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKWF575ofc06JFQSq83Jz99A,696
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization.py,sha256=PKkhc5q8pEPnNLXwo3U56EOCfYnPXIvPs0LlCGZOoKU,4426
-model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=xCYL36K0nK41VSsLcy52uDA7zVfoLxhubmOrtXbqw7s,39140
+model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_calculator.py,sha256=MP4Q5lThvEIhfa1iBajQQM3nCUNgK-2yseqQQ8Rgiog,40624
 model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py,sha256=QQwtl08DiDxUOQGpYPnek_RlZjWm1Ky7tL2ESHXMK78,4050
 model_compression_toolkit/core/common/mixed_precision/search_methods/__init__.py,sha256=sw7LOPN1bM82o3SkMaklyH0jw-TLGK0-fl2Wq73rffI,697
-model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=rSSN5MhH5BO5b58d8pe2pY9wc5HbfescoUStfg-nWfk,7263
+model_compression_toolkit/core/common/mixed_precision/search_methods/linear_programming.py,sha256=TaK5NqVdmygsHw9_x5JsJ-BPvlbKA9cRyTno1R8gbnU,7269
 model_compression_toolkit/core/common/network_editors/__init__.py,sha256=vZmu55bYqiaOQs3AjfwWDXHmuKZcLHt-wm7uR5fPEqg,1307
 model_compression_toolkit/core/common/network_editors/actions.py,sha256=nid0_j-Cn10xvmztT8yCKW_6uA7JEnom9SW9syx7wc0,19594
 model_compression_toolkit/core/common/network_editors/edit_network.py,sha256=dfgawi-nB0ocAJ0xcGn9E-Zv203oUnQLuMiXpX8vTgA,1748
@@ -435,13 +435,14 @@ model_compression_toolkit/target_platform_capabilities/constants.py,sha256=BFSgD
 model_compression_toolkit/target_platform_capabilities/immutable.py,sha256=YhROBiXEIB3TU-bAFrnL3qbAsb1yuWPBAQ_CLOJbYUU,1827
 model_compression_toolkit/target_platform_capabilities/tpc_io_handler.py,sha256=4ydTWWKv_PEOAFok2JtxFNj8rav-0IlqcXKF6lnhHNE,4157
 model_compression_toolkit/target_platform_capabilities/schema/__init__.py,sha256=pKAdbTCFM_2BrZXUtTIw0ouKotrWwUDF_hP3rPwCM2k,696
-model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=PvO8eHxnb3A55gyExT5fZGnOUl3ce7BbbT5SPxCEXNo,541
+model_compression_toolkit/target_platform_capabilities/schema/mct_current_schema.py,sha256=hf539WJ3nBGn0RnALXrKmAPnbhJ-VmWmLIa207x8b4M,541
 model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py,sha256=vBkXxVJagm9JKB9cdm4Pvi7u_luriXUjvNn0-m8Zr0k,4653
 model_compression_toolkit/target_platform_capabilities/schema/v1.py,sha256=4CGpWENuOyjwaIMaGrFI0Act7jsSeT7m94pjrv91dxE,27516
+model_compression_toolkit/target_platform_capabilities/schema/v2.py,sha256=vUhCocA0EcjdR741Yv48W4Kr5Pq22Miebhm7F9GKb3Y,6086
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/__init__.py,sha256=XjNws3zoiJkeH4ixKqrLA5xBvpv5rq31qX7wYQjNpZM,1447
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2fw.py,sha256=HJ8uc3PFfyxg-WpVXPBg4mGaox8Z9bRqtQNbRfIyAk4,3745
-model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=mxc3DBbUi-HDFgSx8Nmnyxr8SIdbx8lmtcRMsQl1BLE,7578
-model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=WPCqs_aFGE28XJf7KKB-SlrYoUNOcD9epgoaqQMCJMw,6320
+model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2keras.py,sha256=AE09QLE_QKwNqUTZbkZP9XLJStG1ECiTWmEGuXZTEsQ,7652
+model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attach2pytorch.py,sha256=-zbPmzQJal-1vZiQ6vIBBBnlEOB2DTb09koA0Aj4I_I,6396
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/attribute_filter.py,sha256=jfhszvuD2Fyy6W2KjlLzXBQKFzTqGAaDZeFVr4-ONQw,8776
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/current_tpc.py,sha256=_kFG0USYa6yzvLsi82_Vusv_KR8Hi7J1u680pPXECuo,2192
 model_compression_toolkit/target_platform_capabilities/targetplatform2framework/framework_quantization_capabilities.py,sha256=UKzckLYLdBcFAptyKnVMwpPpfRkmF0SK1Kl0g0eGjQA,9710
@@ -526,7 +527,7 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
 model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=UVN_S9ULHBEldBpShCOt8-soT8YTQ5oE362y96qF_FA,3950
 model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
 model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
-mct_nightly-2.3.0.20250331.610.dist-info/METADATA,sha256=BSWkSPN58Xfzc8eImgBxCG3v_wXYt8c7l8oFNOOoyGw,27098
-mct_nightly-2.3.0.20250331.610.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-mct_nightly-2.3.0.20250331.610.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
-mct_nightly-2.3.0.20250331.610.dist-info/RECORD,,
+mct_nightly-2.3.0.20250402.536.dist-info/METADATA,sha256=v7bHr6SXyb9qkbAmhb3xAHDXzinP-56xYnw-3SMtHVc,27098
+mct_nightly-2.3.0.20250402.536.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+mct_nightly-2.3.0.20250402.536.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
+mct_nightly-2.3.0.20250402.536.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.3.0.20250331.000610"
+__version__ = "2.3.0.20250402.000536"
@@ -696,6 +696,23 @@ class Graph(nx.MultiDiGraph, GraphSearches):
         sorted_conf_activation = self.get_sorted_activation_configurable_nodes()
         return [(n, n.final_activation_quantization_cfg.activation_n_bits) for n in sorted_conf_activation]
 
+    def retrieve_preserved_quantization_node(self, node: BaseNode) -> BaseNode:
+        """
+        For a node with quantization_preserving == True, walk back to the closest preceding
+        non-quantization-preserving node, from which the activation quantization config is taken.
+        If quantization_preserving is False, the node itself is returned.
+
+        Args:
+            node: quantization preserving node.
+
+        Returns:
+            The node that the quantization preserving node should get the activation quantization from.
+        """
+        while node.is_quantization_preserving():
+            prev_nodes = self.get_prev_nodes(node)
+            assert len(prev_nodes) == 1, "Activation preserving node should have only 1 input."
+            node = prev_nodes[0]
+        return node
+
     def update_fused_nodes(self, fusion: List[Any]):
         """
         Updates the graphs fusions list with a new list of nodes that have been fused.
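
For context, here is a minimal standalone sketch of the walk-back this new method performs, using hypothetical stand-in node objects rather than MCT's real `Graph`/`BaseNode` API (which resolves predecessors via `get_prev_nodes`):

```python
# Hypothetical stand-ins for illustration only.
class _Node:
    def __init__(self, name, preserving, prev=None):
        self.name = name
        self._preserving = preserving
        self.prev = prev  # single input node, or None

    def is_quantization_preserving(self):
        return self._preserving

def retrieve_preserved_quantization_node(node):
    # Walk back through quantization-preserving ops (e.g. reshape, dropout)
    # until reaching a node whose activation is actually quantized.
    while node.is_quantization_preserving():
        assert node.prev is not None, "Activation preserving node should have only 1 input."
        node = node.prev
    return node

conv = _Node("conv", preserving=False)
reshape = _Node("reshape", preserving=True, prev=conv)
dropout = _Node("dropout", preserving=True, prev=reshape)
assert retrieve_preserved_quantization_node(dropout) is conv
```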
@@ -131,6 +131,19 @@ class BaseNode:
                qc.activation_quantization_cfg.enable_activation_quantization
         return self.candidates_quantization_cfg[0].activation_quantization_cfg.enable_activation_quantization
 
+    def is_quantization_preserving(self) -> bool:
+        """
+        Returns: Whether the node's activation quantization information is preserved from its inputs.
+        """
+        if self.final_activation_quantization_cfg:
+            # If a final configuration exists, its quantization_preserving flag is authoritative.
+            return self.final_activation_quantization_cfg.quantization_preserving
+
+        for qc in self.candidates_quantization_cfg:
+            assert self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving == \
+                   qc.activation_quantization_cfg.quantization_preserving
+        return self.candidates_quantization_cfg[0].activation_quantization_cfg.quantization_preserving
+
     def is_weights_quantization_enabled(self, attr_name: str) -> bool:
         """
         Checks whether a node's weights attribute quantization is enabled.
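
Until a final config is set, the flag is read from the quantization candidates, which must all agree. A toy sketch of that agreement check (stand-in config objects, not MCT's real `NodeActivationQuantizationConfig`):

```python
class _ActCfg:
    def __init__(self, preserving: bool):
        self.quantization_preserving = preserving

def is_quantization_preserving(final_cfg, candidates):
    # With a final config, its flag is the single source of truth.
    if final_cfg is not None:
        return final_cfg.quantization_preserving
    # Otherwise every candidate must carry the same flag.
    assert all(c.quantization_preserving == candidates[0].quantization_preserving
               for c in candidates), "Candidates disagree on quantization_preserving."
    return candidates[0].quantization_preserving

assert is_quantization_preserving(None, [_ActCfg(True), _ActCfg(True)]) is True
assert is_quantization_preserving(_ActCfg(False), [_ActCfg(True)]) is False
```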
@@ -335,13 +335,35 @@ class ResourceUtilizationCalculator:
         """
         return self.compute_activation_utilization_by_cut(target_criterion, bitwidth_mode, act_qcs)
 
+    def _extract_qc(self, n: BaseNode, act_qcs: Optional[ActivationQCfgPerNode] = None
+                    ) -> Union[NodeActivationQuantizationConfig, None]:
+        """
+        Extract the quantization config if the activation configs dictionary is provided. If the node is
+        quantization preserving, extract the quantization config from the preceding activation-quantized
+        node (i.e. the quantization that the original node preserves).
+
+        Args:
+            n: Node to extract the qc for.
+            act_qcs: Custom activation quantization configurations. If not provided, the default
+                configuration will be extracted from the node.
+
+        Returns:
+            The relevant quantization config.
+        """
+        if act_qcs:
+            assert not (n.is_quantization_preserving() and act_qcs.get(n.name) is not None), \
+                f"Quantization preserving node {n.name} should not have a qc for this computation."
+            return act_qcs.get(self.graph.retrieve_preserved_quantization_node(n).name)
+        return None
+
     def compute_activation_utilization_by_cut(self,
                                               target_criterion: TargetInclusionCriterion,
                                               bitwidth_mode: BitwidthMode,
                                               act_qcs: Optional[ActivationQCfgPerNode] = None) \
             -> Tuple[float, Dict[Cut, Utilization], Dict[Cut, Dict[BaseNode, Utilization]]]:
         """
-        Compute graph activation cuts utilization.
+        Compute graph activation cuts utilization. If activation quantization configs are provided,
+        quantization preserving nodes take the bit-width of the previous quantized activation node.
 
         Args:
             target_criterion: criterion to include weights for computation.
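
A self-contained toy of the `_extract_qc` lookup redirection (the `source_name` argument here is a hypothetical stand-in for the graph walk-back shown earlier):

```python
def extract_qc(node_name: str, is_preserving: bool, source_name: str, act_qcs: dict):
    if not act_qcs:
        return None
    # A preserving node must not carry a custom config of its own...
    assert not (is_preserving and act_qcs.get(node_name) is not None), \
        f"Quantization preserving node {node_name} should not have a qc."
    # ...it inherits the config of the quantized node it preserves.
    return act_qcs.get(source_name if is_preserving else node_name)

act_qcs = {"conv": {"activation_n_bits": 4}}
assert extract_qc("reshape", True, "conv", act_qcs) == {"activation_n_bits": 4}
assert extract_qc("conv", False, "conv", act_qcs) == {"activation_n_bits": 4}
```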
@@ -369,7 +391,7 @@ class ResourceUtilizationCalculator:
             if not cut_target_nodes:
                 continue
             for n in cut_target_nodes:
-                qc = act_qcs.get(n.name) if act_qcs else None
+                qc = self._extract_qc(n, act_qcs)
                 util_per_cut_per_node[cut][n.name] = self.compute_node_activation_tensor_utilization(n, target_criterion,
                                                                                                      bitwidth_mode, qc)
             util_per_cut[cut] = sum(util_per_cut_per_node[cut].values())  # type: ignore
@@ -384,7 +406,8 @@ class ResourceUtilizationCalculator:
                                      include_reused=False) \
             -> Tuple[float, Dict[NodeName, Utilization]]:
         """
-        Compute resource utilization for graph's activations tensors.
+        Compute resource utilization for the graph's activation tensors. If activation quantization
+        configs are provided, quantization preserving nodes take the bit-width of the previous
+        quantized activation node.
 
         Args:
             target_criterion: criterion to include weights for computation.
@@ -405,7 +428,7 @@ class ResourceUtilizationCalculator:
 
         util_per_node: Dict[NodeName, Utilization] = {}
         for n in self._topo_sort(nodes):
-            qc = act_qcs.get(n.name) if act_qcs else None
+            qc = self._extract_qc(n, act_qcs)
             util = self.compute_node_activation_tensor_utilization(n, None, bitwidth_mode, qc)
             util_per_node[n.name] = util
 
@@ -659,7 +682,7 @@ class ResourceUtilizationCalculator:
         if target_criterion == TargetInclusionCriterion.QConfigurable:
             nodes = [n for n in nodes if n.has_configurable_activation()]
         elif target_criterion == TargetInclusionCriterion.AnyQuantized:
-            nodes = [n for n in nodes if n.is_activation_quantization_enabled()]
+            nodes = [n for n in nodes if n.is_activation_quantization_enabled() or n.is_quantization_preserving()]
         elif target_criterion == TargetInclusionCriterion.QNonConfigurable:
             nodes = [n for n in nodes if n.is_activation_quantization_enabled() and not n.has_configurable_activation()]
         elif target_criterion != TargetInclusionCriterion.Any:  # pragma: no cover
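
The AnyQuantized criterion now also keeps quantization-preserving nodes. A tiny illustration with stand-in `(name, quantization_enabled, preserving)` tuples:

```python
nodes = [("conv", True, False), ("reshape", False, True), ("softmax", False, False)]
any_quantized = [name for name, quant_enabled, preserving in nodes
                 if quant_enabled or preserving]
assert any_quantized == ["conv", "reshape"]
```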
@@ -668,8 +691,7 @@ class ResourceUtilizationCalculator:
             nodes = [n for n in nodes if not n.reuse]
         return nodes
 
-    @classmethod
-    def _get_activation_nbits(cls,
+    def _get_activation_nbits(self,
                               n: BaseNode,
                               bitwidth_mode: BitwidthMode,
                               act_qc: Optional[NodeActivationQuantizationConfig]) -> int:
@@ -690,21 +712,22 @@ class ResourceUtilizationCalculator:
             assert bitwidth_mode == BitwidthMode.QCustom
             return act_qc.activation_n_bits if act_qc.enable_activation_quantization else FLOAT_BITWIDTH
 
-        if bitwidth_mode == BitwidthMode.Float or not n.is_activation_quantization_enabled():
+        if bitwidth_mode == BitwidthMode.Float or not (n.is_activation_quantization_enabled() or
+                                                       n.is_quantization_preserving()):
             return FLOAT_BITWIDTH
 
         if bitwidth_mode == BitwidthMode.Q8Bit:
             return 8
 
-        if bitwidth_mode in cls._bitwidth_mode_fn:
+        if bitwidth_mode in self._bitwidth_mode_fn:
             candidates_nbits = [c.activation_quantization_cfg.activation_n_bits for c in n.candidates_quantization_cfg]
-            return cls._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
+            return self._bitwidth_mode_fn[bitwidth_mode](candidates_nbits)
 
         if bitwidth_mode in [BitwidthMode.QCustom, BitwidthMode.QDefaultSP]:
-            qcs = n.get_unique_activation_candidates()
+            qcs = self.graph.retrieve_preserved_quantization_node(n).get_unique_activation_candidates()
             if len(qcs) != 1:
                 raise ValueError(f'Could not retrieve the activation quantization candidate for node {n} '
-                                 f'as it has {len(qcs)}!=1 unique candidates .')
+                                 f'as it has {len(qcs)}!=1 unique candidates.')
             return qcs[0].activation_quantization_cfg.activation_n_bits
 
         raise ValueError(f'Unknown mode {bitwidth_mode}')  # pragma: no cover
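
`_get_activation_nbits` resolves a node's effective activation bit-width in a fixed order. A simplified sketch of that order follows; the string mode names and the min/max candidate reduction are assumptions inferred from the visible branches, and `FLOAT_BITWIDTH` is MCT's 32-bit constant for unquantized tensors:

```python
FLOAT_BITWIDTH = 32  # MCT's constant for unquantized tensors

def get_activation_nbits(mode: str, quantized: bool, preserving: bool,
                         candidate_nbits: list) -> int:
    # Float mode, or a node that neither quantizes nor preserves, stays in float.
    if mode == "Float" or not (quantized or preserving):
        return FLOAT_BITWIDTH
    if mode == "Q8Bit":
        return 8
    # Aggregating modes reduce over the mixed-precision candidates
    # (min/max reduction assumed, mirroring _bitwidth_mode_fn).
    if mode in ("QMinBit", "QMaxBit"):
        return (min if mode == "QMinBit" else max)(candidate_nbits)
    # QCustom / QDefaultSP require exactly one unique candidate (see above).
    if mode in ("QCustom", "QDefaultSP"):
        if len(candidate_nbits) != 1:
            raise ValueError(f"Expected 1 unique candidate, got {len(candidate_nbits)}.")
        return candidate_nbits[0]
    raise ValueError(f"Unknown mode {mode}")

assert get_activation_nbits("Float", True, False, [8, 4]) == 32
assert get_activation_nbits("QMinBit", True, False, [8, 4]) == 4
```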
@@ -14,7 +14,7 @@
 # ==============================================================================
 import numpy as np
 from pulp import *
-from typing import Dict, Tuple
+from typing import Dict, Tuple, List
 
 from model_compression_toolkit.core.common.mixed_precision.resource_utilization_tools.resource_utilization import RUTarget
 
@@ -1,4 +1,4 @@
-import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
+import model_compression_toolkit.target_platform_capabilities.schema.v2 as schema
 
 OperatorSetNames = schema.OperatorSetNames
 Signedness = schema.Signedness
@@ -0,0 +1,177 @@
+# Copyright 2025 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pprint
+from enum import Enum
+from typing import Dict, Any, Tuple, Optional
+
+from pydantic import BaseModel, root_validator
+
+from mct_quantizers import QuantizationMethod
+from model_compression_toolkit.constants import FLOAT_BITWIDTH
+from model_compression_toolkit.logger import Logger
+from model_compression_toolkit.target_platform_capabilities.schema.v1 import (
+    Signedness,
+    AttributeQuantizationConfig,
+    OpQuantizationConfig,
+    QuantizationConfigOptions,
+    TargetPlatformModelComponent,
+    OperatorsSetBase,
+    OperatorsSet,
+    OperatorSetGroup,
+    Fusing)
+
+
+class OperatorSetNames(str, Enum):
+    CONV = "Conv"
+    DEPTHWISE_CONV = "DepthwiseConv2D"
+    CONV_TRANSPOSE = "ConvTranspose"
+    FULLY_CONNECTED = "FullyConnected"
+    CONCATENATE = "Concatenate"
+    STACK = "Stack"
+    UNSTACK = "Unstack"
+    GATHER = "Gather"
+    EXPAND = "Expend"
+    BATCH_NORM = "BatchNorm"
+    L2NORM = "L2Norm"
+    RELU = "ReLU"
+    RELU6 = "ReLU6"
+    LEAKY_RELU = "LeakyReLU"
+    ELU = "Elu"
+    HARD_TANH = "HardTanh"
+    ADD = "Add"
+    SUB = "Sub"
+    MUL = "Mul"
+    DIV = "Div"
+    MIN = "Min"
+    MAX = "Max"
+    PRELU = "PReLU"
+    ADD_BIAS = "AddBias"
+    SWISH = "Swish"
+    SIGMOID = "Sigmoid"
+    SOFTMAX = "Softmax"
+    LOG_SOFTMAX = "LogSoftmax"
+    TANH = "Tanh"
+    GELU = "Gelu"
+    HARDSIGMOID = "HardSigmoid"
+    HARDSWISH = "HardSwish"
+    FLATTEN = "Flatten"
+    GET_ITEM = "GetItem"
+    RESHAPE = "Reshape"
+    UNSQUEEZE = "Unsqueeze"
+    SQUEEZE = "Squeeze"
+    PERMUTE = "Permute"
+    TRANSPOSE = "Transpose"
+    DROPOUT = "Dropout"
+    SPLIT_CHUNK = "SplitChunk"
+    MAXPOOL = "MaxPool"
+    AVGPOOL = "AvgPool"
+    SIZE = "Size"
+    SHAPE = "Shape"
+    EQUAL = "Equal"
+    ARGMAX = "ArgMax"
+    TOPK = "TopK"
+    FAKE_QUANT = "FakeQuant"
+    COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression"
+    BOX_DECODE = "BoxDecode"
+    ZERO_PADDING2D = "ZeroPadding2D"
+    CAST = "Cast"
+    RESIZE = "Resize"
+    PAD = "Pad"
+    FOLD = "Fold"
+    STRIDED_SLICE = "StridedSlice"
+    SSD_POST_PROCESS = "SSDPostProcess"
+
+    @classmethod
+    def get_values(cls):
+        return [v.value for v in cls]
+
+
+class TargetPlatformCapabilities(BaseModel):
+    """
+    Represents the hardware configuration used for quantized model inference.
+
+    Attributes:
+        default_qco (QuantizationConfigOptions): Default quantization configuration options for the model.
+        operator_set (Optional[Tuple[OperatorsSet, ...]]): Tuple of operator sets within the model.
+        fusing_patterns (Optional[Tuple[Fusing, ...]]): Tuple of fusing patterns for the model.
+        tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration.
+        tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration.
+        tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration.
+        add_metadata (bool): Flag to determine if metadata should be added.
+        name (str): Name of the Target Platform Model.
+        is_simd_padding (bool): Indicates if SIMD padding is applied.
+        SCHEMA_VERSION (int): Version of the schema for the Target Platform Model.
+    """
+    default_qco: QuantizationConfigOptions
+    operator_set: Optional[Tuple[OperatorsSet, ...]]
+    fusing_patterns: Optional[Tuple[Fusing, ...]]
+    tpc_minor_version: Optional[int]
+    tpc_patch_version: Optional[int]
+    tpc_platform_type: Optional[str]
+    add_metadata: bool = True
+    name: Optional[str] = "default_tpc"
+    is_simd_padding: bool = False
+
+    SCHEMA_VERSION: int = 2
+
+    class Config:
+        frozen = True
+
+    @root_validator(allow_reuse=True)
+    def validate_after_initialization(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Perform validation after the model has been instantiated.
+
+        Args:
+            values (Dict[str, Any]): The instantiated target platform model.
+
+        Returns:
+            Dict[str, Any]: The validated values.
+        """
+        # Validate `default_qco`
+        default_qco = values.get('default_qco')
+        if len(default_qco.quantization_configurations) != 1:
+            Logger.critical("Default QuantizationConfigOptions must contain exactly one option.")  # pragma: no cover
+
+        # Validate `operator_set` uniqueness
+        operator_set = values.get('operator_set')
+        if operator_set is not None:
+            opsets_names = [
+                op.name.value if isinstance(op.name, OperatorSetNames) else op.name
+                for op in operator_set
+            ]
+            if len(set(opsets_names)) != len(opsets_names):
+                Logger.critical("Operator Sets must have unique names.")  # pragma: no cover
+
+        return values
+
+    def get_info(self) -> Dict[str, Any]:
+        """
+        Get a dictionary summarizing the TargetPlatformCapabilities properties.
+
+        Returns:
+            Dict[str, Any]: Summary of the TargetPlatformCapabilities properties.
+        """
+        return {
+            "Model name": self.name,
+            "Operators sets": [o.get_info() for o in self.operator_set] if self.operator_set else [],
+            "Fusing patterns": [f.get_info() for f in self.fusing_patterns] if self.fusing_patterns else [],
+        }
+
+    def show(self):
+        """
+        Display the TargetPlatformCapabilities.
+        """
+        pprint.pprint(self.get_info(), sort_dicts=False)
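
A self-contained toy showing the same pydantic v1-style pattern the new schema uses (a frozen model plus a `root_validator` enforcing an invariant); `ToyCapabilities` is illustrative only, and the real class additionally validates `default_qco`:

```python
from typing import Any, Dict, Optional, Tuple

from pydantic import BaseModel, root_validator


class ToyCapabilities(BaseModel):
    name: Optional[str] = "default_tpc"
    operator_set: Optional[Tuple[str, ...]] = None

    class Config:
        frozen = True  # instances are immutable once constructed, like the schema above

    @root_validator(allow_reuse=True)
    def validate_unique_opsets(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        ops = values.get('operator_set')
        if ops is not None and len(set(ops)) != len(ops):
            raise ValueError("Operator Sets must have unique names.")
        return values


tpc = ToyCapabilities(operator_set=("Conv", "FullyConnected"))  # ok
# ToyCapabilities(operator_set=("Conv", "Conv"))  # raises a ValidationError
```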
@@ -93,6 +93,7 @@ class AttachTpcToKeras(AttachTpcToFramework):
             OperatorSetNames.TOPK: [tf.nn.top_k],
             OperatorSetNames.FAKE_QUANT: [tf.quantization.fake_quant_with_min_max_vars],
             OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [tf.image.combined_non_max_suppression],
+            OperatorSetNames.BOX_DECODE: [],  # no such operator in keras
             OperatorSetNames.ZERO_PADDING2D: [ZeroPadding2D],
             OperatorSetNames.CAST: [tf.cast],
             OperatorSetNames.STRIDED_SLICE: [tf.strided_slice],
@@ -97,7 +97,8 @@ class AttachTpcToPytorch(AttachTpcToFramework):
             OperatorSetNames.L2NORM: [LayerFilterParams(torch.nn.functional.normalize,
                                                         Eq('p', 2) | Eq('p', None))],
             OperatorSetNames.SSD_POST_PROCESS: [],  # no such operator in pytorch
-            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: []  # no such operator in pytorch
+            OperatorSetNames.COMBINED_NON_MAX_SUPPRESSION: [],  # no such operator in pytorch
+            OperatorSetNames.BOX_DECODE: []  # no such operator in pytorch
         }
 
         pytorch_linear_attr_mapping = {KERNEL_ATTR: DefaultDict(default_value=PYTORCH_KERNEL),
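
In both attach classes an empty list marks an opset with no framework counterpart, as with BOX_DECODE above. A dependency-free sketch of how such a mapping separates supported from unsupported opsets (hypothetical helper, not MCT's real AttachTpcToFramework API):

```python
# Hypothetical excerpt; values are operator names rather than callables
# to keep the sketch dependency-free.
OPSET_TO_FRAMEWORK_OPS = {
    "Add": ["torch.add"],
    "BoxDecode": [],  # no such operator in pytorch
}

def is_supported(opset_name: str) -> bool:
    # An opset is realizable only if at least one operator is attached to it.
    return bool(OPSET_TO_FRAMEWORK_OPS.get(opset_name))

assert is_supported("Add")
assert not is_supported("BoxDecode")
```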