mct-nightly 2.2.0.20240930.532__py3-none-any.whl → 2.2.0.20241002.500__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mct-nightly
3
- Version: 2.2.0.20240930.532
3
+ Version: 2.2.0.20241002.500
4
4
  Summary: A Model Compression Toolkit for neural networks
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- model_compression_toolkit/__init__.py,sha256=t1wNg0lS5JpWEPKyQf-PoKxWqEgc_58HfooVwjVCFsQ,1573
1
+ model_compression_toolkit/__init__.py,sha256=0OYwjkiM5Okt4kzkKaRTqc3Iq-TCsVN1uOUzYDgffog,1573
2
2
  model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
3
3
  model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
4
4
  model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -15,7 +15,7 @@ model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7w
15
15
  model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
16
16
  model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
17
17
  model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
18
- model_compression_toolkit/core/common/model_collector.py,sha256=ofcepKtxc3j2Ouz6BpAKXTzPgjABnpRP47ndmJCXAkk,8352
18
+ model_compression_toolkit/core/common/model_collector.py,sha256=T0J3hLmqJI8eQEXlBfqbnPNJ4XpPUp0zfRSjL0CQYu8,8381
19
19
  model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
20
20
  model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
21
21
  model_compression_toolkit/core/common/similarity_analyzer.py,sha256=FikcIqgQQpfiXr9VJvgl-wk8OyH7-LvC8ku7TkhJfJM,9200
@@ -219,7 +219,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
219
219
  model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
220
220
  model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
221
221
  model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
222
- model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=xmcJyU-rkIDX1a_X9LILzf2Ko2z_4I4xnlHkezKH-2w,27669
222
+ model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=2RGf4ii9zxJwGLA3mp-qzDp4khFaYNUNN95bNuNNZ0c,27868
223
223
  model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
224
224
  model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
225
225
  model_compression_toolkit/core/pytorch/utils.py,sha256=GE7T8q93I5C4As0iOias_dk9HpOvXM1N6---dJlyD60,3863
@@ -249,6 +249,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_
249
249
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
250
250
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
251
251
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
252
+ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=ziL7jwTnjzTf7BHPRPYgWBSCUrSXSyjZnvQqsJhD1nM,12466
252
253
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
253
254
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
254
255
  model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
@@ -550,8 +551,8 @@ tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2
550
551
  tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
551
552
  tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
552
553
  tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
553
- mct_nightly-2.2.0.20240930.532.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
554
- mct_nightly-2.2.0.20240930.532.dist-info/METADATA,sha256=p4oG8xi2574mBzHWDgHCKzuBAz49q9DLDya346NWcYc,20830
555
- mct_nightly-2.2.0.20240930.532.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
556
- mct_nightly-2.2.0.20240930.532.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
557
- mct_nightly-2.2.0.20240930.532.dist-info/RECORD,,
554
+ mct_nightly-2.2.0.20241002.500.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
555
+ mct_nightly-2.2.0.20241002.500.dist-info/METADATA,sha256=XUo1iMNL1fh6tGsBz-kglXfHvxhfWdOBebgokDVQJ4A,20830
556
+ mct_nightly-2.2.0.20241002.500.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
557
+ mct_nightly-2.2.0.20241002.500.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
558
+ mct_nightly-2.2.0.20241002.500.dist-info/RECORD,,
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
27
27
  from model_compression_toolkit import pruning
28
28
  from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
29
29
 
30
- __version__ = "2.2.0.20240930.000532"
30
+ __version__ = "2.2.0.20241002.000500"
@@ -158,7 +158,7 @@ class ModelCollector:
158
158
  for td, sc in zip(tensor_data, self.stats_containers_list):
159
159
  if isinstance(sc, (list, tuple)):
160
160
  if not isinstance(td, (list, tuple)):
161
- Logger.critical('\'tensor_data\' must be a list or a tuple if \'stats_containers_list\' contains lists or tuples.') # pragma: no cover
161
+ Logger.critical(f"\'tensor_data\' is of type {type(td)} but must be of the same type as \'stats_containers_list\', which is of type {type(sc)}") # pragma: no cover
162
162
  if len(sc) != len(td):
163
163
  Logger.critical('\'tensor_data\' and \'stats_containers_list\' must have matching lengths') # pragma: no cover
164
164
  for tdi, sci in zip(td, sc):
@@ -0,0 +1,231 @@
1
+ # Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import torch.nn as nn
17
+ import torch
18
+ import math
19
+ from copy import copy
20
+ import numpy as np
21
+ from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
22
+ from model_compression_toolkit.core.common import BaseSubstitution
23
+ from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
24
+ from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
25
+ from model_compression_toolkit.core.pytorch.constants import DIM
26
+ from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
27
+
28
+
29
class ScaledDotProductDecomposition(BaseSubstitution):
    """
    Decompose torch.nn.functional.scaled_dot_product_attention into its base operators:
        Transpose (over k)
        MatMul (over q and transposed k)
        Mul (for scaling)
        Add (for masking. Optional operation, used in cases that attn_mask is given)
        Softmax
        Dropout
        MatMul (over the attention weights and v).
    """

    def __init__(self):
        """
        Matches scaled_dot_product_attention node.
        """
        super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))

    def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str,
                           input_index: int, default_value: object) -> object:
        """
        Search for an attention_node input value in op_call_kwargs (using input_name) and then in
        op_call_args (using input_index). In case the input is not given, returns its default_value.

        Args:
            attention_node: The scaled_dot_product_attention node.
            input_name: Keyword name of the input (e.g. "is_causal").
            input_index: Position of the input in op_call_args. The index convention used in this
                class is [attn_mask, dropout_p, is_causal, scale].
            default_value: Value returned when the input is found in neither place.

        Returns:
            The input value, or default_value if it was not given.
        """
        if input_name in attention_node.op_call_kwargs:
            return attention_node.op_call_kwargs[input_name]
        if len(attention_node.op_call_args) > input_index:  # input order: [attn_mask, dropout_p, is_causal]
            return attention_node.op_call_args[input_index]
        return default_value

    def _get_attention_input_nodes(self, graph: Graph, attention_node: FunctionalNode) -> dict:
        """
        Get the nodes that feed the query, key and value inputs of the attention node.

        Args:
            graph: The graph containing attention_node.
            attention_node: The scaled_dot_product_attention node.

        Returns:
            A dict with keys "q", "k" and "v" mapping to the corresponding input nodes.

        Raises:
            NotImplementedError: If the node has more than 3 incoming tensors. A constant attn_mask is
                read from the node weights, so an extra graph input is presumably a non-constant mask;
                the decomposition cannot represent it and would otherwise drop it silently.
        """
        prev_nodes = graph.get_prev_nodes(attention_node, sink_index_sorted=True)
        if len(prev_nodes) > 3:
            raise NotImplementedError(
                "scaled_dot_product_attention with more than 3 tensor inputs (e.g. a non-constant attn_mask) "
                "is not supported.")
        q_node, k_node, v_node = prev_nodes[0], prev_nodes[1], prev_nodes[2]
        return {"q": q_node, "k": k_node, "v": v_node}

    def _get_transpose_k_node(self, attention_node_name: str, key_node: BaseNode) -> BaseNode:
        """
        Build a node that swaps the last two axes of the key tensor.

        Args:
            attention_node_name: Name of the attention node (used as a name prefix).
            key_node: The node producing the key tensor.

        Returns:
            A torch.transpose FunctionalNode over axes (-1, -2).
        """
        # list() (rather than copy()) so item assignment also works when the shape is a tuple.
        input_shape = list(key_node.output_shape[0])
        output_shape = list(key_node.output_shape[0])
        output_shape[-2], output_shape[-1] = input_shape[-1], input_shape[-2]
        return FunctionalNode(name=f"{attention_node_name}_{key_node.name}_transpose",
                              framework_attr={},
                              input_shape=input_shape,
                              output_shape=output_shape,
                              weights={},
                              layer_class=torch.transpose,
                              op_call_args=[-1, -2],  # axes to transpose
                              op_call_kwargs={},
                              functional_op=torch.transpose)

    def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matmul_node: BaseNode) -> FunctionalNode:
        """
        Build the multiplication node that applies the attention scale factor.

        The scale is taken from the node's "scale" input when given; otherwise it defaults to
        1 / sqrt(E), where E is the last axis of the query tensor.

        Args:
            attention_node: The scaled_dot_product_attention node.
            q_node: The node producing the query tensor.
            matmul_node: The q @ k^T matmul node whose output is scaled.

        Returns:
            A torch.mul FunctionalNode representing multiplication by the scale factor.
        """
        q_embd_axis = -1
        input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
        # Compare against None explicitly so a user-supplied scale of 0.0 is honored.
        if input_scale is not None:
            scale_factor = input_scale
        else:
            scale_factor = 1 / math.sqrt(q_node.output_shape[0][q_embd_axis])
        return FunctionalNode(name=f'{attention_node.name}_scale',
                              framework_attr={},
                              input_shape=matmul_node.output_shape,
                              output_shape=matmul_node.output_shape,
                              weights={},
                              layer_class=torch.mul,
                              op_call_args=[scale_factor],
                              op_call_kwargs={},
                              functional_op=torch.mul)

    def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transposed_k_node: BaseNode) -> BaseNode:
        """
        Build the first matmul node: q @ transposed(k).

        Args:
            attention_node_name: Name of the attention node (used as a name prefix).
            q_node: The node producing the query tensor.
            transposed_k_node: The transpose node applied to the key tensor.

        Returns:
            A torch.matmul FunctionalNode. Its output shape equals q's shape with the last axis
            replaced by the last axis of transposed k.
        """
        matmul1_output_shape = list(q_node.output_shape[0])
        matmul1_output_shape[-1] = transposed_k_node.output_shape[-1]
        return FunctionalNode(name=f'{attention_node_name}_matmul1',
                              framework_attr={},
                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
                              output_shape=tuple(matmul1_output_shape),
                              weights={},
                              layer_class=torch.matmul,
                              op_call_args=[],
                              op_call_kwargs={},
                              functional_op=torch.matmul)

    def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalNode) -> FunctionalNode:
        """
        Build the Add node that adds the attention mask to the scaled scores.

        Args:
            attention_node: The scaled_dot_product_attention node.
            scale_node: The scale node whose output the mask is added to.

        Returns:
            A torch.add FunctionalNode with the mask tensor as its 'other' argument, or None when no
            mask tensor is given.
        """
        attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
        if attention_mask_tensor is None:
            return None
        return FunctionalNode(name=f'{attention_node.name}_mask',
                              framework_attr={},
                              input_shape=scale_node.output_shape,
                              output_shape=scale_node.output_shape,
                              weights={},
                              layer_class=torch.add,
                              op_call_args=[],
                              op_call_kwargs={'other': attention_mask_tensor},
                              functional_op=torch.add)

    def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple) -> BaseNode:
        """
        Build an nn.Softmax node over the last axis.

        Args:
            attention_node_name: Name of the attention node (used as a name prefix).
            in_out_shape: Input (and output) shape of the softmax.

        Returns:
            An nn.Softmax BaseNode with dim=-1.
        """
        return BaseNode(name=f'{attention_node_name}_softmax',
                        framework_attr={DIM: -1},
                        input_shape=in_out_shape,
                        output_shape=in_out_shape,
                        weights={},
                        layer_class=nn.Softmax)

    def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_node: BaseNode) -> FunctionalNode:
        """
        Build the second matmul node: attention weights @ v.

        Args:
            attention_node_name: Name of the attention node (used as a name prefix).
            softmax_node: The softmax node; its output shape is used as the attention-weights shape.
            v_node: The node producing the value tensor.

        Returns:
            A torch.matmul FunctionalNode. Its output shape equals the attention-weights shape with the
            last axis replaced by the last axis of v.
        """
        matmul2_output_shape = list(softmax_node.output_shape)
        matmul2_output_shape[-1] = v_node.output_shape[0][-1]
        return FunctionalNode(name=f'{attention_node_name}_matmul2',
                              framework_attr={},
                              input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
                              output_shape=tuple(matmul2_output_shape),
                              weights={},
                              layer_class=torch.matmul,
                              op_call_args=[],
                              op_call_kwargs={},
                              functional_op=torch.matmul)

    def _get_attention_mask_tensor(self, attention_node: FunctionalNode) -> torch.Tensor:
        """
        Get the mask tensor given as part of the attention node input.

        Since MCT doesn't support infinite values, is_causal (a torch scaled_dot_product_attention argument)
        and a boolean mask tensor are not supported, as they both require -inf values.

        Args:
            attention_node: The scaled_dot_product_attention node.

        Returns:
            The first weight of the node, converted to a torch tensor on the working device, or None if
            the node has no weights.

        Raises:
            NotImplementedError: If is_causal is set, or the mask is boolean or contains non-finite values.
        """
        device = get_working_device()
        is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
        if is_causal:
            raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
        input_weights = list(attention_node.weights.values())
        attn_mask = input_weights[0] if len(input_weights) > 0 else None
        if attn_mask is None:
            return None
        if attn_mask.dtype == "bool":
            raise NotImplementedError(
                "scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
        if not np.isfinite(attn_mask).all():
            raise NotImplementedError(
                "scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
        return torch.from_numpy(attn_mask).to(device)

    def _get_dropout_node(self, attention_node: FunctionalNode, in_out_shape: tuple) -> BaseNode:
        """
        Build an nn.Dropout node using the attention node's dropout_p.

        Args:
            attention_node: The scaled_dot_product_attention node.
            in_out_shape: Input (and output) shape of the dropout.

        Returns:
            An nn.Dropout BaseNode. Its p is taken from dropout_p, given either as a keyword or
            positionally (index 1 of [attn_mask, dropout_p, is_causal]), defaulting to 0.
        """
        # Use the same kwarg-then-positional lookup as is_causal, so a positional dropout_p isn't ignored.
        dropout_p = self._get_input_by_name(attention_node, "dropout_p", 1, 0)
        return BaseNode(name=f'{attention_node.name}_dropout',
                        framework_attr={"p": dropout_p},
                        input_shape=in_out_shape,
                        output_shape=in_out_shape,
                        weights={},
                        layer_class=nn.Dropout)

    def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
        """
        Remove a scaled_dot_product_attention node from the graph and replace it with an equivalent
        sub-graph that consists of:
            Transpose (over k)
            MatMul (over q and transposed k)
            Mul (for scaling)
            Add (for masking. Optional operation, used in cases that attn_mask is given)
            Softmax
            Dropout
            MatMul (over the attention weights and v).

        :param graph: A Graph to apply substitution on
        :param attention_node: the node to replace
        :return: A graph after the substitution
        """
        input_nodes = self._get_attention_input_nodes(graph, attention_node)
        q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
        transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
        matmul_node = self._get_matmul_node(attention_node.name, q_node, transpose_k_node)
        scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
        mask_node = self._get_mask_node(attention_node, scale_node)
        softmax_node = self._get_softmax_node(attention_node.name, matmul_node.output_shape)
        dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
        matmul2_node = self._get_matmul2_node(attention_node.name, softmax_node, v_node)

        graph.add_node_with_in_edges(transpose_k_node, [k_node])
        graph.add_node_with_in_edges(matmul_node, [q_node, transpose_k_node])
        graph.add_node_with_in_edges(scale_node, [matmul_node])
        # The mask Add is optional: softmax consumes the masked scores when a mask exists,
        # otherwise the scaled scores directly.
        if mask_node is not None:
            graph.add_node_with_in_edges(mask_node, [scale_node])
            softmax_input_node = mask_node
        else:
            softmax_input_node = scale_node
        graph.add_node_with_in_edges(softmax_node, [softmax_input_node])
        graph.add_node_with_in_edges(dropout_node, [softmax_node])
        graph.add_node_with_in_edges(matmul2_node, [dropout_node, v_node])

        # Graph outputs that pointed at the attention node now point at the final matmul.
        graph_outputs = graph.get_outputs()
        for i, g_out in enumerate(graph_outputs):
            if g_out.node == attention_node:
                graph_outputs[i] = OutTensor(node=matmul2_node, node_out_index=g_out.node_out_index)

        graph.reconnect_out_edges(current_node=attention_node, new_node=matmul2_node)
        graph.remove_edge(q_node, attention_node)
        graph.remove_edge(k_node, attention_node)
        graph.remove_edge(v_node, attention_node)
        graph.remove_node(attention_node, new_graph_outputs=graph_outputs)
        return graph
@@ -53,6 +53,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.li
53
53
  pytorch_linear_collapsing
54
54
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
55
55
  import MultiHeadAttentionDecomposition
56
+ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.scaled_dot_product_attention import \
57
+ ScaledDotProductDecomposition
56
58
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
57
59
  TransformFunctionCallMethod
58
60
  from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
@@ -237,6 +239,7 @@ class PytorchImplementation(FrameworkImplementation):
237
239
  """
238
240
  return [ReshapeWithStaticShapes(),
239
241
  MultiHeadAttentionDecomposition(),
242
+ ScaledDotProductDecomposition(),
240
243
  TransformFunctionCallMethod(),
241
244
  FunctionalConvSubstitution(fw_info),
242
245
  FunctionalBatchNorm(),