mct-nightly 2.2.0.20241123.518__py3-none-any.whl → 2.2.0.20241125.521__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20241123.518
3
+ Version: 2.2.0.20241125.521
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=-xH9WmXx3mA1ga3UmoUb7EqBNN0auFU_wQv6dSA5WVw,1573
1
+ model_compression_toolkit/__init__.py,sha256=V7JxE-NIGj2DUXx77l3Ho4NbCRrjcf7p_ma6-PrKRlM,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -176,7 +176,7 @@ model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm
176
176
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_reconstruction.py,sha256=GR1a3mCZpNUu4WxixJXF_aSm57phAdxaRoHecNx3hxw,3168
177
177
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/batchnorm_refusing.py,sha256=5df_xGfXkqNub4xVRnCWQvSohWqdv12axjJ6edVU2H0,2478
178
178
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/concat_threshold_update.py,sha256=Hl4LEQ_bw_Vpmf3ZqHujYUqVdvTNsPlEMvr9dZhwg2U,2806
179
- model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=sf4nrcQckDlvfA1CNXCXcp35vHO-TVPrc9tXK2vDV6I,11198
179
+ model_compression_toolkit/core/keras/graph_substitutions/substitutions/conv_funcs_to_layer.py,sha256=RwzqSksGNmN1KPH8RTJzpCSjGgxvtT9kqqPqsjbGPqs,11631
180
180
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/dwconv_to_conv.py,sha256=R3U7cjc2E0zheMem16GHygp5jZFGSaomkNOTxTjcAgw,5794
181
181
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/input_scaling.py,sha256=V6hp67CkS_A3WqdsjLjs0ETtdZAOo4P9mhy4aT7W5FE,5940
182
182
  model_compression_toolkit/core/keras/graph_substitutions/substitutions/linear_collapsing.py,sha256=AvquvVVVT8-ioeVn-gjqysK4L41L3I7TlNOEDfWjViY,8185
@@ -559,8 +559,8 @@ model_compression_toolkit/xquant/pytorch/model_analyzer.py,sha256=b93o800yVB3Z-i
559
559
  model_compression_toolkit/xquant/pytorch/pytorch_report_utils.py,sha256=bOc-hFL3gdoSM1Th_S2N_-9JJSlPGpZCTx_QLJHS6lg,3388
560
560
  model_compression_toolkit/xquant/pytorch/similarity_functions.py,sha256=CERxq5K8rqaiE-DlwhZBTUd9x69dtYJlkHOPLB54vm8,2354
561
561
  model_compression_toolkit/xquant/pytorch/tensorboard_utils.py,sha256=mkoEktLFFHtEKzzFRn_jCnxjhJolK12TZ5AQeDHzUO8,9767
562
- mct_nightly-2.2.0.20241123.518.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
563
- mct_nightly-2.2.0.20241123.518.dist-info/METADATA,sha256=Cu9zX56iaIeLO2feGNpwHp1m3r6sbwKCmUF-ZY5VPUg,26473
564
- mct_nightly-2.2.0.20241123.518.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
565
- mct_nightly-2.2.0.20241123.518.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
566
- mct_nightly-2.2.0.20241123.518.dist-info/RECORD,,
562
+ mct_nightly-2.2.0.20241125.521.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
563
+ mct_nightly-2.2.0.20241125.521.dist-info/METADATA,sha256=P2SQ21pjaU6q--__nRskZCmLp3jjCliPQPLOBCm7VnM,26473
564
+ mct_nightly-2.2.0.20241125.521.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
565
+ mct_nightly-2.2.0.20241125.521.dist-info/top_level.txt,sha256=gsYA8juk0Z-ZmQRKULkb3JLGdOdz8jW_cMRjisn9ga4,26
566
+ mct_nightly-2.2.0.20241125.521.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.45.0)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20241123.000518"
30
+ __version__ = "2.2.0.20241125.000521"
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ from typing import Optional, Tuple
15
16
 
16
17
  import numpy as np
17
18
  import tensorflow as tf
@@ -30,7 +31,7 @@ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOpera
30
31
  from model_compression_toolkit.constants import REUSE, REUSE_GROUP
31
32
  from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
32
33
  KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
33
- ACTIVATION, LINEAR
34
+ ACTIVATION, LINEAR, DATA_FORMAT, GROUPS, CHANNELS_FORMAT_FIRST, CHANNELS_FORMAT_LAST
34
35
 
35
36
 
36
37
  def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
@@ -136,35 +137,35 @@ class Conv2dFuncToConv2dLayer(common.BaseSubstitution):
136
137
  conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
137
138
  if len(conv_func_node.op_call_args) > 0:
138
139
  Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.") # pragma: no cover
139
- if STRIDES in conv_func_node.op_call_kwargs:
140
- strides = conv_func_node.op_call_kwargs[STRIDES]
141
- if len(strides) == 4:
142
- if strides[0] > 1 or strides[3] > 1:
143
- # Non-standard strides -> skip substitution.
144
- return graph # pragma: no cover
145
- conv_fw_attr[STRIDES] = strides[1:3]
146
- else:
147
- conv_fw_attr[STRIDES] = strides
148
- if PADDING in conv_func_node.op_call_kwargs:
149
- padding = conv_func_node.op_call_kwargs[PADDING]
150
- if not isinstance(padding, str):
151
- # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
152
- return graph # pragma: no cover
153
- conv_fw_attr[PADDING] = padding
154
- if DILATIONS in conv_func_node.op_call_kwargs and conv_func_node.op_call_kwargs[DILATIONS] is not None:
155
- dilations = conv_func_node.op_call_kwargs[DILATIONS]
156
- if isinstance(dilations, (list, tuple)) and len(dilations) == 4:
157
- if dilations[0] > 1 or dilations[3] > 1:
158
- # Non-standard dilations -> skip substitution.
159
- return graph # pragma: no cover
160
- conv_fw_attr[DILATION_RATE] = dilations[1:3]
161
- else:
162
- conv_fw_attr[DILATION_RATE] = dilations
140
+
141
+ strides = self._parse_tf_stride_dilation(conv_func_node, STRIDES)
142
+ if strides is None:
143
+ # Non-standard strides -> skip substitution.
144
+ return graph
145
+ conv_fw_attr[STRIDES] = strides
146
+
147
+ padding = conv_func_node.op_call_kwargs.get(PADDING) or 'VALID'
148
+ if not isinstance(padding, str):
149
+ # Non-standard padding, Layer only support either 'valid' or 'same' -> skip substitution.
150
+ return graph # pragma: no cover
151
+ conv_fw_attr[PADDING] = padding
152
+
153
+ dilations = self._parse_tf_stride_dilation(conv_func_node, DILATIONS)
154
+ if dilations is None:
155
+ # Non-standard dilations -> skip substitution.
156
+ return graph
157
+ conv_fw_attr[DILATION_RATE] = dilations
158
+
163
159
  if b is None:
164
160
  conv_fw_attr[USE_BIAS] = False
165
161
  else:
166
162
  weights[BIAS] = b
167
163
 
164
+ data_format = conv_func_node.op_call_kwargs.get(DATA_FORMAT, 'NHWC')
165
+ conv_fw_attr[DATA_FORMAT] = {'NHWC': CHANNELS_FORMAT_LAST, 'NCHW': CHANNELS_FORMAT_FIRST}[data_format]
166
+
167
+ conv_fw_attr[GROUPS] = 1
168
+
168
169
  _reuse_params = {REUSE: conv_func_node.reuse, REUSE_GROUP: conv_func_node.reuse_group}
169
170
  conv_node = BaseNode(conv_func_node.name, conv_fw_attr, conv_func_node.input_shape, conv_func_node.output_shape,
170
171
  weights, Conv2D, **_reuse_params)
@@ -172,6 +173,31 @@ class Conv2dFuncToConv2dLayer(common.BaseSubstitution):
172
173
  replace_conv_node(graph, conv_node, conv_func_node, remove_add_node=b is not None)
173
174
  return graph
174
175
 
176
+ def _parse_tf_stride_dilation(self, node, key) -> Optional[Tuple[int, int]]:
177
+ """
178
+ Extract stride/dilation param from tf node and convert it to keras format (suitable for Conv2D).
179
+
180
+ Args:
181
+ node: node
182
+ key: param key
183
+
184
+ Returns:
185
+ Parsed value or None if non-standard.
186
+ """
187
+ v = node.op_call_kwargs.get(key)
188
+ if v is None:
189
+ return 1, 1
190
+ if isinstance(v, int):
191
+ return v, v
192
+ if len(v) == 1:
193
+ return v[0], v[0]
194
+ if len(v) == 4:
195
+ if v[0] > 1 and v[-1] > 1:
196
+ return None
197
+ else:
198
+ return v[1:3]
199
+ return tuple(v)
200
+
175
201
 
176
202
  class DwConv2dFuncToDwConv2dLayer(common.BaseSubstitution):
177
203
  """