mct-nightly 2.2.0.20240930.532__py3-none-any.whl → 2.2.0.20241002.500__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/METADATA +1 -1
- {mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/RECORD +9 -8
- model_compression_toolkit/__init__.py +1 -1
- model_compression_toolkit/core/common/model_collector.py +1 -1
- model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py +231 -0
- model_compression_toolkit/core/pytorch/pytorch_implementation.py +3 -0
- {mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/LICENSE.md +0 -0
- {mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/WHEEL +0 -0
- {mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/top_level.txt +0 -0
{mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/RECORD
RENAMED
@@ -1,4 +1,4 @@
-model_compression_toolkit/__init__.py,sha256=
+model_compression_toolkit/__init__.py,sha256=0OYwjkiM5Okt4kzkKaRTqc3Iq-TCsVN1uOUzYDgffog,1573
 model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920
 model_compression_toolkit/defaultdict.py,sha256=LSc-sbZYXENMCw3U9F4GiXuv67IKpdn0Qm7Fr11jy-4,2277
 model_compression_toolkit/logger.py,sha256=3DByV41XHRR3kLTJNbpaMmikL8icd9e1N-nkQAY9oDk,4567
@@ -15,7 +15,7 @@ model_compression_toolkit/core/common/framework_implementation.py,sha256=kSg2f7w
 model_compression_toolkit/core/common/framework_info.py,sha256=1ZMMGS9ip-kSflqkartyNRt9aQ5ub1WepuTRcTy-YSQ,6337
 model_compression_toolkit/core/common/memory_computation.py,sha256=ixoSpV5ZYZGyzhre3kQcvR2sNA8KBsPZ3lgbkDnw9Cs,1205
 model_compression_toolkit/core/common/model_builder_mode.py,sha256=jll9-59OPaE3ug7Y9-lLyV99_FoNHxkGZMgcm0Vkpss,1324
-model_compression_toolkit/core/common/model_collector.py,sha256=
+model_compression_toolkit/core/common/model_collector.py,sha256=T0J3hLmqJI8eQEXlBfqbnPNJ4XpPUp0zfRSjL0CQYu8,8381
 model_compression_toolkit/core/common/model_validation.py,sha256=LaG8wd6aZl0OJgieE3SeiVDEPxtk8IHq9-3wSnmWhY4,1214
 model_compression_toolkit/core/common/node_prior_info.py,sha256=WXX_PrGVG9M9I_REG5ZzFBohwmV4yf356sZnrja_FLo,2832
 model_compression_toolkit/core/common/similarity_analyzer.py,sha256=FikcIqgQQpfiXr9VJvgl-wk8OyH7-LvC8ku7TkhJfJM,9200
@@ -219,7 +219,7 @@ model_compression_toolkit/core/pytorch/__init__.py,sha256=Rf1RcYmelmdZmBV5qOKvKW
 model_compression_toolkit/core/pytorch/constants.py,sha256=YwD_joIF0vK8UG2vW1NVvg36pCNWA0vHOXjAgy_XWn0,2794
 model_compression_toolkit/core/pytorch/default_framework_info.py,sha256=-Vls1P_8Ckm_18nnOsmQkZ71SmzHwtQLbQ383Z4Rb-U,4365
 model_compression_toolkit/core/pytorch/pytorch_device_config.py,sha256=S25cuw10AW3SEN_fRAGRcG_I3wdvvQx1ehSJzPnn-UI,4404
-model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=
+model_compression_toolkit/core/pytorch/pytorch_implementation.py,sha256=2RGf4ii9zxJwGLA3mp-qzDp4khFaYNUNN95bNuNNZ0c,27868
 model_compression_toolkit/core/pytorch/pytorch_node_prior_info.py,sha256=2LDQ7qupglHQ7o1Am7LWdfYVacfQnl-aW2N6l9det1w,3264
 model_compression_toolkit/core/pytorch/resource_utilization_data_facade.py,sha256=xpKj99OZKT9NT0vKIl_cOe8d89d2gef1gKoNT6PFElE,4989
 model_compression_toolkit/core/pytorch/utils.py,sha256=GE7T8q93I5C4As0iOias_dk9HpOvXM1N6---dJlyD60,3863
@@ -249,6 +249,7 @@ model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/reshape_with_static_shapes.py,sha256=hAZXzrEinHa-dJHLj39Hy_9Q-13QyO95rtYVSLrhvT8,4915
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/residual_collapsing.py,sha256=DcJEIkGvBdIMOelNIwaJUZ5UsAHiGnDJPR20I464vWo,2929
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scale_equalization.py,sha256=XFtU9yuBmoZlX0f0mS6otMPWMk-RcWs94XdvvTNhW8Y,3303
+model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py,sha256=ziL7jwTnjzTf7BHPRPYgWBSCUrSXSyjZnvQqsJhD1nM,12466
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/shift_negative_activation.py,sha256=3WCLvPyx7tVkM0rwYhYq-gntCzW9R_DcImR1ucKlPac,10772
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/softmax_shift.py,sha256=05lV4pIL3hJkZl4JQPV4wk_EFD0eYLG5b8cdzvZk4P8,1588
 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/transform_function_call_method.py,sha256=EC9Dvp-_UlpDWnipnf8ds65wh_Y-T8pXAFIwRScWpiY,2044
@@ -550,8 +551,8 @@ tests_pytest/pytorch/gptq/test_annealing_cfg.py,sha256=hGC7L6mp3N1ygcJ3OctgS_Fz2
 tests_pytest/pytorch/gptq/test_gradual_act_quantization.py,sha256=tI01aFIUaiCILL5Qn--p1E_rLBUelxLdSY3k52lwcx0,4594
 tests_pytest/pytorch/trainable_infrastructure/__init__.py,sha256=RAe8mgIr1V8dRIQtLf_dSG5zTUCKuQzxyybYx1dzEAs,697
 tests_pytest/pytorch/trainable_infrastructure/test_linear_annealing.py,sha256=eNOpSp0GoLxtEdiRypBp8jaujXfdNxBwKh5Rd-P7WLs,1786
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
-mct_nightly-2.2.0.
+mct_nightly-2.2.0.20241002.500.dist-info/LICENSE.md,sha256=aYSSIb-5AFPeITTvXm1UAoe0uYBiMmSS8flvXaaFUks,10174
+mct_nightly-2.2.0.20241002.500.dist-info/METADATA,sha256=XUo1iMNL1fh6tGsBz-kglXfHvxhfWdOBebgokDVQJ4A,20830
+mct_nightly-2.2.0.20241002.500.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+mct_nightly-2.2.0.20241002.500.dist-info/top_level.txt,sha256=csdfSXhtRnpWYRzjZ-dRLIhOmM2TEdVXUxG05A5fgb8,39
+mct_nightly-2.2.0.20241002.500.dist-info/RECORD,,
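Each RECORD row has the form `path,sha256=<digest>,<size-in-bytes>`, where the digest is an unpadded urlsafe-base64 SHA-256 of the file, so the hash churn above pinpoints exactly which files changed between the two nightlies. A minimal verification sketch (not part of the package; the helper and the site-packages path are hypothetical):

```python
import base64
import hashlib
from pathlib import Path

def verify_record_entry(site_packages: Path, entry: str) -> bool:
    """Check one RECORD row of the form path,sha256=<urlsafe-b64 digest>,<size>."""
    path, hash_field, size = entry.rsplit(",", 2)
    data = (site_packages / path).read_bytes()
    if size and len(data) != int(size):          # size field may be empty (e.g. for RECORD itself)
        return False
    algo, _, expected = hash_field.partition("=")
    digest = base64.urlsafe_b64encode(hashlib.new(algo, data).digest())
    return digest.rstrip(b"=").decode() == expected   # RECORD digests are unpadded

# Hypothetical usage:
# verify_record_entry(Path("venv/lib/python3.10/site-packages"),
#     "model_compression_toolkit/constants.py,sha256=i4wYheBkIdQmsQA-axIpcT3YiSO1USNc-jaNiNE8w6E,3920")
```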
model_compression_toolkit/__init__.py
CHANGED
@@ -27,4 +27,4 @@ from model_compression_toolkit import data_generation
 from model_compression_toolkit import pruning
 from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model
 
-__version__ = "2.2.0.
+__version__ = "2.2.0.20241002.000500"
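Note that `__version__` carries a zero-padded build time (`000500`) while the wheel filename and RECORD entries use `500`; both spell the same PEP 440 version, since leading zeros in a numeric release segment are stripped during normalization. A quick check with the third-party `packaging` library:

```python
from packaging.version import Version  # pip install packaging

v = Version("2.2.0.20241002.000500")
assert str(v) == "2.2.0.20241002.500"      # leading zeros are normalized away
assert v == Version("2.2.0.20241002.500")  # ...so both strings denote the same version
```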
model_compression_toolkit/core/common/model_collector.py
CHANGED
@@ -158,7 +158,7 @@ class ModelCollector:
         for td, sc in zip(tensor_data, self.stats_containers_list):
             if isinstance(sc, (list, tuple)):
                 if not isinstance(td, (list, tuple)):
-                    Logger.critical(
+                    Logger.critical(f"\'tensor_data\' is of type {type(td)} but must be of the same type as \'stats_containers_list\', which is of type {type(sc)}")  # pragma: no cover
                 if len(sc) != len(td):
                     Logger.critical('\'tensor_data\' and \'stats_containers_list\' must have matching lengths')  # pragma: no cover
                 for tdi, sci in zip(td, sc):
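The only change here is a more descriptive `Logger.critical` message: when a node's statistics containers form a list/tuple, the collected tensor data must be a list/tuple of matching length. A standalone sketch of the invariant being enforced (the function is illustrative, not MCT API):

```python
def check_structure(tensor_data, stats_containers_list):
    """Raise if tensor_data doesn't mirror the nesting of stats_containers_list."""
    for td, sc in zip(tensor_data, stats_containers_list):
        if isinstance(sc, (list, tuple)):
            if not isinstance(td, (list, tuple)):
                raise TypeError(f"'tensor_data' is of type {type(td)} but must be of the "
                                f"same type as 'stats_containers_list', which is of type {type(sc)}")
            if len(sc) != len(td):
                raise ValueError("'tensor_data' and 'stats_containers_list' must have matching lengths")

check_structure([[1, 2]], [[object(), object()]])  # OK: both sides are nested
# check_structure([1], [[object(), object()]])     # would raise the TypeError above
```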
model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/scaled_dot_product_attention.py
ADDED
@@ -0,0 +1,231 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch.nn as nn
+import torch
+import math
+from copy import copy
+import numpy as np
+from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
+from model_compression_toolkit.core.common import BaseSubstitution
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.graph.base_graph import Graph, BaseNode, OutTensor
+from model_compression_toolkit.core.pytorch.constants import DIM
+from model_compression_toolkit.core.pytorch.pytorch_device_config import get_working_device
+
+
+class ScaledDotProductDecomposition(BaseSubstitution):
+    """
+    Decompose torch.nn.scale_dot_product into its base operators:
+    Transpose (over k)
+    MatMul(over q and transposed k)
+    Mul (for scaling)
+    Add (for masking. optional operation, used in cases that attn_mask ig given)
+    Dropout
+    Softmax
+    Matmul.
+    """
+
+    def __init__(self):
+        """
+        Matches scaled_dot_product_attention node.
+        """
+        super().__init__(matcher_instance=NodeOperationMatcher(nn.functional.scaled_dot_product_attention))
+
+    def _get_input_by_name(self, attention_node: FunctionalNode, input_name: str,
+                           input_index: int, default_value: any) -> any:
+        """
+        Search for attention_node input value in op_call_kwargs (using input_name) and op_call_args (using input_index).
+        In case the input is not given, returns its default_value.
+
+        """
+        if input_name in attention_node.op_call_kwargs:
+            return attention_node.op_call_kwargs[input_name]
+        elif len(attention_node.op_call_args) > input_index:  # input order: [attn_mask, dropout_p, is_causal]
+            return attention_node.op_call_args[input_index]
+        return default_value
+
+    def _get_attention_input_nodes(self, graph: Graph, attention_node: FunctionalNode) -> dict:
+        q, k, v = 0, 1, 2
+        prev_nodes = graph.get_prev_nodes(attention_node, sink_index_sorted=True)
+        q_node, k_node, v_node = prev_nodes[q], prev_nodes[k], prev_nodes[v]
+        return {"q": q_node, "k": k_node, "v": v_node}
+
+    def _get_transpose_k_node(self, attention_node_name: str, key_node: BaseNode) -> BaseNode:
+        input_shape, output_shape = copy(key_node.output_shape[0]), copy(key_node.output_shape[0])
+        output_shape[-2], output_shape[-1] = input_shape[-1], input_shape[-2]
+        transpose_node = FunctionalNode(name=f"{attention_node_name}_{key_node.name}_transpose",
+                                        framework_attr={},
+                                        input_shape=input_shape,
+                                        output_shape=output_shape,
+                                        weights={},
+                                        layer_class=torch.transpose,
+                                        op_call_args=[-1, -2],  # axes to transpose
+                                        op_call_kwargs={},
+                                        functional_op=torch.transpose)
+        return transpose_node
+
+    def _get_scale_node(self, attention_node: FunctionalNode, q_node: BaseNode, matmul_node: BaseNode) -> FunctionalNode:
+        """
+        :return: multiplication node that represents multiplication by the scale factor
+        """
+        scale_name = f'{attention_node.name}_scale'
+        q_embd_axis = -1
+        input_scale = self._get_input_by_name(attention_node, "scale", 3, None)
+        scale_factor = input_scale if input_scale else (1 / math.sqrt(q_node.output_shape[0][q_embd_axis]))
+        scale_node = FunctionalNode(name=scale_name,
+                                    framework_attr={},
+                                    input_shape=(matmul_node.output_shape),
+                                    output_shape=matmul_node.output_shape,
+                                    weights={},
+                                    layer_class=torch.mul,
+                                    op_call_args=[scale_factor],
+                                    op_call_kwargs={},
+                                    functional_op=torch.mul)
+        return scale_node
+
+    def _get_matmul_node(self, attention_node_name: str, q_node: BaseNode, transposed_k_node: BaseNode) -> BaseNode:
+        matmul1_output_shape = copy(q_node.output_shape[0])
+        matmul1_output_shape[-2] = q_node.output_shape[0][-2]
+        matmul1_output_shape[-1] = transposed_k_node.output_shape[-1]
+        matmul_name = f'{attention_node_name}_matmul1'
+        return FunctionalNode(name=matmul_name,
+                              framework_attr={},
+                              input_shape=(tuple(q_node.output_shape[0]), tuple(transposed_k_node.output_shape)),
+                              output_shape=tuple(matmul1_output_shape),
+                              weights={},
+                              layer_class=torch.matmul,
+                              op_call_args=[],
+                              op_call_kwargs={},
+                              functional_op=torch.matmul)
+
+    def _get_mask_node(self, attention_node: FunctionalNode, scale_node: FunctionalNode) -> FunctionalNode:
+        """
+        :return: Add operator node with the mask tensor as input. In case there is no mask tensor, returns None.
+        """
+        attention_mask_tensor = self._get_attention_mask_tensor(attention_node)
+        if attention_mask_tensor is None:
+            return None
+        mask_node_name = f'{attention_node.name}_mask'
+        return FunctionalNode(name=mask_node_name,
+                              framework_attr={},
+                              input_shape=(scale_node.output_shape),
+                              output_shape=scale_node.output_shape,
+                              weights={},
+                              layer_class=torch.add,
+                              op_call_args=[],
+                              op_call_kwargs={'other': attention_mask_tensor},
+                              functional_op=torch.add)
+
+    def _get_softmax_node(self, attention_node_name: str, in_out_shape: tuple) -> BaseNode:
+        softmax_name = f'{attention_node_name}_softmax'
+        return BaseNode(name=softmax_name,
+                        framework_attr={DIM: -1},
+                        input_shape=in_out_shape,
+                        output_shape=in_out_shape,
+                        weights={},
+                        layer_class=nn.Softmax)
+
+    def _get_matmul2_node(self, attention_node_name: str, softmax_node: BaseNode, v_node: BaseNode) -> FunctionalNode:
+        matmul2_output_shape = list(copy(softmax_node.output_shape))
+        matmul2_output_shape[-2] = softmax_node.output_shape[-2]
+        matmul2_output_shape[-1] = v_node.output_shape[0][-1]
+        matmul2_name = f'{attention_node_name}_matmul2'
+        return FunctionalNode(name=matmul2_name,
+                              framework_attr={},
+                              input_shape=(tuple(softmax_node.output_shape), tuple(v_node.output_shape[0])),
+                              output_shape=tuple(matmul2_output_shape),
+                              weights={},
+                              layer_class=torch.matmul,
+                              op_call_args=[],
+                              op_call_kwargs={},
+                              functional_op=torch.matmul)
+
+    def _get_attention_mask_tensor(self, attention_node: FunctionalNode) -> torch.Tensor:
+        """
+        :return: mask tensor given as part of attention node input.
+        Since MCT doesn't support infinite values, we don't support is_causal (torch.nn.scale_dot_product_attention
+        argument) and boolean mask tensor, as they both require -inf values.
+        """
+        device = get_working_device()
+        is_causal = self._get_input_by_name(attention_node, "is_causal", 2, False)
+        if is_causal:
+            raise NotImplementedError("scaled_dot_product_attention is_causal feature is not implemented.")
+        input_weights = list(attention_node.weights.values())
+        attn_mask = input_weights[0] if len(input_weights) > 0 else None
+        if attn_mask is not None and (attn_mask.dtype == "bool"):
+            raise NotImplementedError(
+                "scaled_dot_product_attention attn_mask is of type boolean, which is not supported.")
+        if attn_mask is not None and (not np.isfinite(attn_mask).all()):
+            raise NotImplementedError(
+                "scaled_dot_product_attention attn_mask contains infinite value, which is not supported.")
+        return torch.from_numpy(attn_mask).to(device) if attn_mask is not None else None
+
+    def _get_dropout_node(self, attention_node: FunctionalNode, in_out_shape: tuple) -> BaseNode:
+        dropout_p = attention_node.op_call_kwargs.get('dropout_p', 0)
+        dropout_name = f'{attention_node.name}_dropout'
+        return BaseNode(name=dropout_name,
+                        framework_attr={"p": dropout_p},
+                        input_shape=in_out_shape,
+                        output_shape=in_out_shape,
+                        weights={},
+                        layer_class=nn.Dropout)
+
+    def substitute(self, graph: Graph, attention_node: FunctionalNode) -> Graph:
+        """
+        Removes a scaled_dot_product_attention node from the graph, and replaces it with a compatible graph that
+        consists of:
+        Transpose (over k)
+        MatMul(over q and transposed k)
+        Mul (for scaling)
+        Add (for masking. optional operation, used in cases that attn_mask ig given)
+        Dropout
+        Softmax
+        Matmul.
+        :param graph: A Graph to apply substitution on
+        :param attention_node: the node to replace
+        :return: A graph after the substitution
+        """
+        print("In scale_dot_product_attention substitution@@@@@@@@")
+        input_nodes = self._get_attention_input_nodes(graph, attention_node)
+        q_node, k_node, v_node = input_nodes["q"], input_nodes["k"], input_nodes["v"]
+        transpose_k_node = self._get_transpose_k_node(attention_node.name, k_node)
+        matmul_node = self._get_matmul_node(attention_node.name, q_node, transpose_k_node)
+        scale_node = self._get_scale_node(attention_node, q_node, matmul_node)
+        mask_node = self._get_mask_node(attention_node, scale_node)
+        softmax_node = self._get_softmax_node(attention_node.name, matmul_node.output_shape)
+        dropout_node = self._get_dropout_node(attention_node, softmax_node.output_shape)
+        matmul2_node = self._get_matmul2_node(attention_node.name, softmax_node, v_node)
+
+        graph.add_node_with_in_edges(transpose_k_node, [k_node])
+        graph.add_node_with_in_edges(matmul_node, [q_node, transpose_k_node])
+        graph.add_node_with_in_edges(scale_node, [matmul_node])
+        if mask_node:
+            graph.add_node_with_in_edges(mask_node, [scale_node])
+        graph.add_node_with_in_edges(softmax_node, [mask_node if mask_node else scale_node])
+        graph.add_node_with_in_edges(dropout_node, [softmax_node])
+        graph.add_node_with_in_edges(matmul2_node, [dropout_node if dropout_node else softmax_node, v_node])
+
+        graph_outputs = graph.get_outputs()
+        for i, g_out in enumerate(graph_outputs):
+            if g_out.node == attention_node:
+                graph_outputs[i] = OutTensor(node=matmul2_node, node_out_index=g_out.node_out_index)
+
+        graph.reconnect_out_edges(current_node=attention_node, new_node=matmul2_node)
+        graph.remove_edge(q_node, attention_node)
+        graph.remove_edge(k_node, attention_node)
+        graph.remove_edge(v_node, attention_node)
+        graph.remove_node(attention_node, new_graph_outputs=graph_outputs)
+        return graph
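The decomposition above reproduces the math of the fused op: softmax(q·kᵀ·scale + mask)·v, with dropout applied between the softmax and the second matmul. A hedged sanity check against PyTorch's reference (not from the package; dropout_p is left at 0 so the comparison is deterministic, and the shapes are arbitrary):

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 8, 16)    # (batch, heads, seq, embed)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)
mask = torch.randn(2, 4, 8, 8)  # additive float mask with finite values only

# The op chain the substitution builds: transpose(k), matmul, mul(scale), add(mask), softmax, matmul.
scale = 1 / math.sqrt(q.shape[-1])
attn = torch.matmul(q, torch.transpose(k, -1, -2)) * scale + mask
decomposed = torch.matmul(torch.softmax(attn, dim=-1), v)

reference = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
torch.testing.assert_close(decomposed, reference)
```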
model_compression_toolkit/core/pytorch/pytorch_implementation.py
CHANGED
@@ -53,6 +53,8 @@ from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.li
     pytorch_linear_collapsing
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.multi_head_attention_decomposition \
     import MultiHeadAttentionDecomposition
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.scaled_dot_product_attention import \
+    ScaledDotProductDecomposition
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.transform_function_call_method import \
     TransformFunctionCallMethod
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.const_holder_conv import \
@@ -237,6 +239,7 @@ class PytorchImplementation(FrameworkImplementation):
         """
         return [ReshapeWithStaticShapes(),
                 MultiHeadAttentionDecomposition(),
+                ScaledDotProductDecomposition(),
                 TransformFunctionCallMethod(),
                 FunctionalConvSubstitution(fw_info),
                 FunctionalBatchNorm(),
{mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/LICENSE.md
RENAMED
File without changes
{mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/WHEEL
RENAMED
File without changes
{mct_nightly-2.2.0.20240930.532.dist-info → mct_nightly-2.2.0.20241002.500.dist-info}/top_level.txt
RENAMED
File without changes