onnx2tf 1.29.5__py3-none-any.whl → 1.29.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/ops/AffineGrid.py +187 -0
- onnx2tf/ops/Attention.py +612 -0
- onnx2tf/ops/BlackmanWindow.py +115 -0
- onnx2tf/ops/GridSample.py +466 -369
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.7.dist-info}/METADATA +6 -6
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.7.dist-info}/RECORD +11 -8
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.7.dist-info}/WHEEL +1 -1
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.7.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.7.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.7.dist-info}/top_level.txt +0 -0
onnx2tf/ops/Attention.py
ADDED
@@ -0,0 +1,612 @@
+import random
+random.seed(0)
+import numpy as np
+np.random.seed(0)
+import tensorflow as tf
+import onnx_graphsurgeon as gs
+from onnx2tf.utils.common_functions import (
+    replace_parameter,
+    get_constant_or_variable,
+    print_node_info,
+    inverted_operation_enable_disable,
+    make_tf_node_info,
+    get_replacement_parameter,
+    pre_process_transpose,
+    post_process_transpose,
+    transpose_with_flexing_deterrence,
+)
+from onnx2tf.utils.enums import (
+    ONNX_DTYPES_TO_TF_DTYPES,
+)
+from typing import Any, Optional
+
+
+def _is_empty_input(input_var: Any) -> bool:
+    if input_var is None:
+        return True
+    if isinstance(input_var, str) and input_var == "":
+        return True
+    if hasattr(input_var, 'name') and input_var.name == "":
+        return True
+    return False
+
+
+def _to_tf_tensor(input_var: Any, tf_layers_dict: dict) -> Any:
+    return tf_layers_dict[input_var.name]['tf_node'] \
+        if isinstance(input_var, gs.Variable) else input_var
+
+
+def _pad_or_slice_last_dim(
+    mask_tensor: Any,
+    target_len: Any,
+    pad_value: Any,
+) -> Any:
+    mask_shape = tf.shape(mask_tensor)
+    last_dim = mask_shape[-1]
+
+    def _pad():
+        pad_len = target_len - last_dim
+        rank = tf.rank(mask_tensor)
+        paddings = tf.concat(
+            values=[
+                tf.zeros(
+                    shape=tf.stack([rank - 1, 2]),
+                    dtype=tf.int32,
+                ),
+                tf.reshape(tf.stack([0, pad_len]), (1, 2)),
+            ],
+            axis=0,
+        )
+        return tf.pad(
+            tensor=mask_tensor,
+            paddings=paddings,
+            constant_values=pad_value,
+        )
+
+    def _slice():
+        return mask_tensor[..., :target_len]
+
+    def _identity():
+        return mask_tensor
+
+    return tf.cond(
+        pred=last_dim < target_len,
+        true_fn=_pad,
+        false_fn=lambda: tf.cond(
+            pred=last_dim > target_len,
+            true_fn=_slice,
+            false_fn=_identity,
+        ),
+    )
+
+
+@print_node_info
+@inverted_operation_enable_disable
+@get_replacement_parameter
+def make_node(
+    *,
+    graph_node: gs.Node,
+    tf_layers_dict: dict,
+    **kwargs: dict,
+):
+    """Attention
+
+    Parameters
+    ----------
+    graph_node: gs.Node
+        graph_surgeon Node
+
+    tf_layers_dict: dict
+        optype, shape, dtype, tensorflow graph
+    """
+    before_op_output_shape_trans_q = \
+        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_k = \
+        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_v = \
+        tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans = \
+        before_op_output_shape_trans_q \
+        and before_op_output_shape_trans_k \
+        and before_op_output_shape_trans_v
+
+    graph_node_input_q = get_constant_or_variable(
+        graph_node.inputs[0],
+        before_op_output_shape_trans,
+    )
+    graph_node_input_k = get_constant_or_variable(
+        graph_node.inputs[1],
+        before_op_output_shape_trans,
+    )
+    graph_node_input_v = get_constant_or_variable(
+        graph_node.inputs[2],
+        before_op_output_shape_trans,
+    )
+
+    attn_mask_input = None
+    past_key_input = None
+    past_value_input = None
+    nonpad_kv_seqlen_input = None
+
+    if len(graph_node.inputs) >= 4 and not _is_empty_input(graph_node.inputs[3]):
+        attn_mask_input = graph_node.inputs[3]
+    if len(graph_node.inputs) >= 5 and not _is_empty_input(graph_node.inputs[4]):
+        past_key_input = graph_node.inputs[4]
+    if len(graph_node.inputs) >= 6 and not _is_empty_input(graph_node.inputs[5]):
+        past_value_input = graph_node.inputs[5]
+    if len(graph_node.inputs) >= 7 and not _is_empty_input(graph_node.inputs[6]):
+        nonpad_kv_seqlen_input = graph_node.inputs[6]
+
+    if (past_key_input is None) != (past_value_input is None):
+        past_key_input = None
+        past_value_input = None
+
+    Q = _to_tf_tensor(graph_node_input_q, tf_layers_dict)
+    K = _to_tf_tensor(graph_node_input_k, tf_layers_dict)
+    V = _to_tf_tensor(graph_node_input_v, tf_layers_dict)
+
+    # Pre-process transpose
+    Q = pre_process_transpose(
+        value_before_transpose=Q,
+        param_target='inputs',
+        param_name=graph_node.inputs[0].name,
+        **kwargs,
+    )
+    K = pre_process_transpose(
+        value_before_transpose=K,
+        param_target='inputs',
+        param_name=graph_node.inputs[1].name,
+        **kwargs,
+    )
+    V = pre_process_transpose(
+        value_before_transpose=V,
+        param_target='inputs',
+        param_name=graph_node.inputs[2].name,
+        **kwargs,
+    )
+
+    attn_mask = None
+    if attn_mask_input is not None:
+        graph_node_input_attn = get_constant_or_variable(
+            attn_mask_input,
+            before_op_output_shape_trans,
+        )
+        attn_mask = _to_tf_tensor(graph_node_input_attn, tf_layers_dict)
+        attn_mask = pre_process_transpose(
+            value_before_transpose=attn_mask,
+            param_target='inputs',
+            param_name=attn_mask_input.name,
+            **kwargs,
+        )
+
+    past_key = None
+    past_value = None
+    if past_key_input is not None and past_value_input is not None:
+        graph_node_input_past_key = get_constant_or_variable(
+            past_key_input,
+            before_op_output_shape_trans,
+        )
+        graph_node_input_past_value = get_constant_or_variable(
+            past_value_input,
+            before_op_output_shape_trans,
+        )
+        past_key = _to_tf_tensor(graph_node_input_past_key, tf_layers_dict)
+        past_value = _to_tf_tensor(graph_node_input_past_value, tf_layers_dict)
+        past_key = pre_process_transpose(
+            value_before_transpose=past_key,
+            param_target='inputs',
+            param_name=past_key_input.name,
+            **kwargs,
+        )
+        past_value = pre_process_transpose(
+            value_before_transpose=past_value,
+            param_target='inputs',
+            param_name=past_value_input.name,
+            **kwargs,
+        )
+
+    nonpad_kv_seqlen = None
+    if nonpad_kv_seqlen_input is not None:
+        graph_node_input_nonpad = get_constant_or_variable(
+            nonpad_kv_seqlen_input,
+            before_op_output_shape_trans,
+        )
+        nonpad_kv_seqlen = _to_tf_tensor(graph_node_input_nonpad, tf_layers_dict)
+        nonpad_kv_seqlen = pre_process_transpose(
+            value_before_transpose=nonpad_kv_seqlen,
+            param_target='inputs',
+            param_name=nonpad_kv_seqlen_input.name,
+            **kwargs,
+        )
+
+    graph_node_output_y: gs.Variable = graph_node.outputs[0]
+    y_shape = graph_node_output_y.shape
+    y_dtype = graph_node_output_y.dtype
+
+    graph_node_output_present_key: Optional[gs.Variable] = None
+    graph_node_output_present_value: Optional[gs.Variable] = None
+    graph_node_output_qk_matmul_output: Optional[gs.Variable] = None
+    if len(graph_node.outputs) >= 2 and not _is_empty_input(graph_node.outputs[1]):
+        graph_node_output_present_key = graph_node.outputs[1]
+    if len(graph_node.outputs) >= 3 and not _is_empty_input(graph_node.outputs[2]):
+        graph_node_output_present_value = graph_node.outputs[2]
+    if len(graph_node.outputs) >= 4 and not _is_empty_input(graph_node.outputs[3]):
+        graph_node_output_qk_matmul_output = graph_node.outputs[3]
+
+    if (graph_node_output_present_key is None) != (graph_node_output_present_value is None):
+        graph_node_output_present_key = None
+        graph_node_output_present_value = None
+
+    # Preserving Graph Structure (Dict)
+    tf_layers_dict[graph_node_output_y.name] = {
+        'optype': graph_node.op,
+        'shape': y_shape,
+        'dtype': y_dtype,
+    }
+    if graph_node_output_present_key is not None:
+        tf_layers_dict[graph_node_output_present_key.name] = {
+            'optype': graph_node.op,
+            'shape': graph_node_output_present_key.shape,
+            'dtype': graph_node_output_present_key.dtype,
+        }
+    if graph_node_output_present_value is not None:
+        tf_layers_dict[graph_node_output_present_value.name] = {
+            'optype': graph_node.op,
+            'shape': graph_node_output_present_value.shape,
+            'dtype': graph_node_output_present_value.dtype,
+        }
+    if graph_node_output_qk_matmul_output is not None:
+        tf_layers_dict[graph_node_output_qk_matmul_output.name] = {
+            'optype': graph_node.op,
+            'shape': graph_node_output_qk_matmul_output.shape,
+            'dtype': graph_node_output_qk_matmul_output.dtype,
+        }
+
+    # Attributes
+    opset = kwargs['opset']
+    is_causal = bool(graph_node.attrs.get('is_causal', 0))
+    qk_matmul_output_mode = int(graph_node.attrs.get('qk_matmul_output_mode', 0))
+    softcap = graph_node.attrs.get('softcap', 0.0)
+    scale = graph_node.attrs.get('scale', None)
+    q_num_heads = graph_node.attrs.get('q_num_heads', None)
+    kv_num_heads = graph_node.attrs.get('kv_num_heads', None)
+    softmax_precision = graph_node.attrs.get('softmax_precision', None)
+
+    is_causal = replace_parameter(
+        value_before_replacement=is_causal,
+        param_target='attributes',
+        param_name='is_causal',
+        **kwargs,
+    )
+    qk_matmul_output_mode = replace_parameter(
+        value_before_replacement=qk_matmul_output_mode,
+        param_target='attributes',
+        param_name='qk_matmul_output_mode',
+        **kwargs,
+    )
+    softcap = replace_parameter(
+        value_before_replacement=softcap,
+        param_target='attributes',
+        param_name='softcap',
+        **kwargs,
+    )
+    scale = replace_parameter(
+        value_before_replacement=scale,
+        param_target='attributes',
+        param_name='scale',
+        **kwargs,
+    )
+    q_num_heads = replace_parameter(
+        value_before_replacement=q_num_heads,
+        param_target='attributes',
+        param_name='q_num_heads',
+        **kwargs,
+    )
+    kv_num_heads = replace_parameter(
+        value_before_replacement=kv_num_heads,
+        param_target='attributes',
+        param_name='kv_num_heads',
+        **kwargs,
+    )
+    softmax_precision = replace_parameter(
+        value_before_replacement=softmax_precision,
+        param_target='attributes',
+        param_name='softmax_precision',
+        **kwargs,
+    )
+
+    # Reshape 3D inputs to 4D if needed
+    q_rank = len(Q.shape) if Q.shape is not None else None
+    k_rank = len(K.shape) if K.shape is not None else None
+    v_rank = len(V.shape) if V.shape is not None else None
+
+    input_q_is_3d = q_rank == 3
+    input_k_is_3d = k_rank == 3
+    input_v_is_3d = v_rank == 3
+    input_is_3d = input_q_is_3d and input_k_is_3d and input_v_is_3d
+
+    if input_is_3d:
+        if q_num_heads is None:
+            q_num_heads = 1
+        if kv_num_heads is None:
+            kv_num_heads = 1
+
+        q_shape = tf.shape(Q)
+        k_shape = tf.shape(K)
+        v_shape = tf.shape(V)
+
+        q_num_heads_tensor = \
+            tf.constant(q_num_heads, dtype=tf.int32) \
+                if isinstance(q_num_heads, int) else tf.cast(q_num_heads, tf.int32)
+        kv_num_heads_tensor = \
+            tf.constant(kv_num_heads, dtype=tf.int32) \
+                if isinstance(kv_num_heads, int) else tf.cast(kv_num_heads, tf.int32)
+
+        q_head_size = tf.math.floordiv(q_shape[2], q_num_heads_tensor)
+        k_head_size = tf.math.floordiv(k_shape[2], kv_num_heads_tensor)
+        v_head_size = tf.math.floordiv(v_shape[2], kv_num_heads_tensor)
+
+        Q = tf.reshape(
+            tensor=Q,
+            shape=tf.stack([q_shape[0], q_shape[1], q_num_heads_tensor, q_head_size]),
+        )
+        K = tf.reshape(
+            tensor=K,
+            shape=tf.stack([k_shape[0], k_shape[1], kv_num_heads_tensor, k_head_size]),
+        )
+        V = tf.reshape(
+            tensor=V,
+            shape=tf.stack([v_shape[0], v_shape[1], kv_num_heads_tensor, v_head_size]),
+        )
+        Q = transpose_with_flexing_deterrence(
+            input_tensor=Q,
+            perm=[0, 2, 1, 3],
+            **kwargs,
+        )
+        K = transpose_with_flexing_deterrence(
+            input_tensor=K,
+            perm=[0, 2, 1, 3],
+            **kwargs,
+        )
+        V = transpose_with_flexing_deterrence(
+            input_tensor=V,
+            perm=[0, 2, 1, 3],
+            **kwargs,
+        )
+
+    # Ensure dtype alignment for matmul
+    q_dtype = Q.dtype
+    if K.dtype != q_dtype:
+        K = tf.cast(K, q_dtype)
+    if V.dtype != q_dtype:
+        V = tf.cast(V, q_dtype)
+
+    if past_key is not None and past_value is not None:
+        if past_key.dtype != q_dtype:
+            past_key = tf.cast(past_key, q_dtype)
+        if past_value.dtype != q_dtype:
+            past_value = tf.cast(past_value, q_dtype)
+        K = tf.concat([past_key, K], axis=2)
+        V = tf.concat([past_value, V], axis=2)
+
+    present_key = K
+    present_value = V
+
+    # Heads for GQA/MQA
+    q_heads = Q.shape[1] if Q.shape[1] is not None else tf.shape(Q)[1]
+    kv_heads = present_key.shape[1] if present_key.shape[1] is not None else tf.shape(present_key)[1]
+
+    attn_key = present_key
+    attn_value = present_value
+    if isinstance(q_heads, int) and isinstance(kv_heads, int):
+        if kv_heads != q_heads:
+            repeat = q_heads // kv_heads
+            attn_key = tf.repeat(attn_key, repeats=repeat, axis=1)
+            attn_value = tf.repeat(attn_value, repeats=repeat, axis=1)
+    else:
+        repeat = tf.math.floordiv(tf.cast(q_heads, tf.int32), tf.cast(kv_heads, tf.int32))
+        attn_key = tf.repeat(attn_key, repeats=repeat, axis=1)
+        attn_value = tf.repeat(attn_value, repeats=repeat, axis=1)
+
+    # Scale Q and K
+    head_size = tf.shape(Q)[-1]
+    if scale is None:
+        scale_value = tf.math.rsqrt(tf.cast(head_size, q_dtype))
+    else:
+        scale_value = tf.cast(scale, q_dtype)
+
+    scale_sqrt = tf.sqrt(scale_value)
+    Q = Q * scale_sqrt
+    K_scaled = attn_key * scale_sqrt
+
+    # QK^T
+    qk_matmul_output = tf.matmul(
+        a=Q,
+        b=K_scaled,
+        transpose_b=True,
+    )
+
+    q_seq_len = tf.shape(Q)[2]
+    kv_seq_len = tf.shape(present_key)[2]
+
+    attn_bias = None
+
+    if is_causal:
+        causal_mask = tf.linalg.band_part(
+            tf.ones(shape=tf.stack([q_seq_len, kv_seq_len]), dtype=tf.float32),
+            -1,
+            0,
+        )
+        causal_mask = tf.cast(causal_mask, tf.bool)
+        neg_inf = tf.constant(-np.inf, dtype=q_dtype)
+        causal_bias = tf.where(
+            condition=causal_mask,
+            x=tf.zeros_like(causal_mask, dtype=q_dtype),
+            y=neg_inf,
+        )
+        attn_bias = causal_bias
+
+    if attn_mask is not None and attn_mask != "":
+        if attn_mask.dtype != tf.bool:
+            attn_mask = tf.cast(attn_mask, q_dtype)
+            pad_value = tf.constant(-np.inf, dtype=q_dtype)
+        else:
+            pad_value = False
+
+        if opset >= 24:
+            attn_mask = _pad_or_slice_last_dim(
+                mask_tensor=attn_mask,
+                target_len=kv_seq_len,
+                pad_value=pad_value,
+            )
+
+        if attn_mask.dtype == tf.bool:
+            neg_inf = tf.constant(-np.inf, dtype=q_dtype)
+            mask_bias = tf.where(
+                condition=attn_mask,
+                x=tf.zeros_like(attn_mask, dtype=q_dtype),
+                y=neg_inf,
+            )
+        else:
+            mask_bias = attn_mask
+
+        attn_bias = mask_bias if attn_bias is None else attn_bias + mask_bias
+
+    if opset >= 24 and nonpad_kv_seqlen is not None and past_key is None and past_value is None:
+        nonpad_kv_seqlen = tf.cast(nonpad_kv_seqlen, tf.int32)
+        seq_range = tf.range(kv_seq_len, dtype=tf.int32)
+        seq_range = tf.reshape(seq_range, shape=[1, -1])
+        nonpad_mask = seq_range < tf.reshape(nonpad_kv_seqlen, shape=[-1, 1])
+        nonpad_mask = tf.reshape(nonpad_mask, shape=[-1, 1, 1, kv_seq_len])
+        neg_inf = tf.constant(-np.inf, dtype=q_dtype)
+        nonpad_bias = tf.where(
+            condition=nonpad_mask,
+            x=tf.zeros_like(nonpad_mask, dtype=q_dtype),
+            y=neg_inf,
+        )
+        attn_bias = nonpad_bias if attn_bias is None else attn_bias + nonpad_bias
+
+    if attn_bias is not None:
+        qk_with_bias = qk_matmul_output + attn_bias
+    else:
+        qk_with_bias = qk_matmul_output
+
+    # Softcap
+    qk_softcap = qk_with_bias
+    if isinstance(softcap, (float, int, np.floating, np.integer)):
+        if softcap > 0:
+            softcap_value = tf.cast(softcap, q_dtype)
+            qk_softcap = softcap_value * tf.math.tanh(qk_with_bias / softcap_value)
+    else:
+        softcap_value = tf.cast(softcap, q_dtype)
+        safe_softcap = tf.where(
+            condition=tf.equal(softcap_value, 0),
+            x=tf.ones_like(softcap_value),
+            y=softcap_value,
+        )
+        qk_softcap = tf.where(
+            condition=softcap_value > 0,
+            x=softcap_value * tf.math.tanh(qk_with_bias / safe_softcap),
+            y=qk_with_bias,
+        )
+
+    # Softmax
+    softmax_dtype = None
+    if softmax_precision is not None:
+        if softmax_precision in ONNX_DTYPES_TO_TF_DTYPES:
+            softmax_dtype = ONNX_DTYPES_TO_TF_DTYPES[softmax_precision]
+        elif int(softmax_precision) == 16:
+            softmax_dtype = tf.bfloat16
+
+    qk_softmax_input = qk_softcap
+    if softmax_dtype is not None and softmax_dtype != qk_softmax_input.dtype:
+        qk_softmax_input = tf.cast(qk_softmax_input, softmax_dtype)
+
+    qk_softmax = tf.nn.softmax(
+        logits=qk_softmax_input,
+        axis=-1,
+    )
+
+    if softmax_dtype is not None and qk_softmax.dtype != q_dtype:
+        qk_softmax = tf.cast(qk_softmax, q_dtype)
+
+    # Output
+    Y = tf.matmul(
+        a=qk_softmax,
+        b=attn_value,
+    )
+
+    if input_is_3d:
+        Y = transpose_with_flexing_deterrence(
+            input_tensor=Y,
+            perm=[0, 2, 1, 3],
+            **kwargs,
+        )
+        y_shape_dyn = tf.shape(Y)
+        Y = tf.reshape(
+            tensor=Y,
+            shape=tf.stack([y_shape_dyn[0], y_shape_dyn[1], y_shape_dyn[2] * y_shape_dyn[3]]),
+        )
+
+    tf_layers_dict[graph_node_output_y.name]['tf_node'] = Y
+
+    # Outputs for KV cache
+    if graph_node_output_present_key is not None:
+        tf_layers_dict[graph_node_output_present_key.name]['tf_node'] = present_key
+    if graph_node_output_present_value is not None:
+        tf_layers_dict[graph_node_output_present_value.name]['tf_node'] = present_value
+
+    # qk_matmul_output output mode
+    if graph_node_output_qk_matmul_output is not None:
+        qk_output = qk_matmul_output
+        if qk_matmul_output_mode == 1:
+            qk_output = qk_with_bias
+        elif qk_matmul_output_mode == 2:
+            qk_output = qk_softcap
+        elif qk_matmul_output_mode == 3:
+            qk_output = qk_softmax
+        tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'] = qk_output
+
+    # Post-process transpose
+    tf_layers_dict[graph_node_output_y.name]['tf_node'] = post_process_transpose(
+        value_before_transpose=tf_layers_dict[graph_node_output_y.name]['tf_node'],
+        param_target='outputs',
+        param_name=graph_node.outputs[0].name,
+        **kwargs,
+    )
+    if graph_node_output_present_key is not None:
+        tf_layers_dict[graph_node_output_present_key.name]['tf_node'] = post_process_transpose(
+            value_before_transpose=tf_layers_dict[graph_node_output_present_key.name]['tf_node'],
+            param_target='outputs',
+            param_name=graph_node_output_present_key.name,
+            **kwargs,
+        )
+    if graph_node_output_present_value is not None:
+        tf_layers_dict[graph_node_output_present_value.name]['tf_node'] = post_process_transpose(
+            value_before_transpose=tf_layers_dict[graph_node_output_present_value.name]['tf_node'],
+            param_target='outputs',
+            param_name=graph_node_output_present_value.name,
+            **kwargs,
+        )
+    if graph_node_output_qk_matmul_output is not None:
+        tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'] = post_process_transpose(
+            value_before_transpose=tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'],
+            param_target='outputs',
+            param_name=graph_node_output_qk_matmul_output.name,
+            **kwargs,
+        )
+
+    # Generation of Debug Info
+    tf_layers_dict[graph_node_output_y.name]['tf_node_info'] = \
+        make_tf_node_info(
+            node_info={
+                'tf_op_type': 'Attention',
+                'tf_inputs': {
+                    'a': qk_softmax,
+                    'b': attn_value,
+                },
+                'tf_outputs': {
+                    'output': tf_layers_dict[graph_node_output_y.name]['tf_node'],
+                },
+            }
+        )
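
For reference, a minimal sketch of exercising the converter added above, assuming an ONNX model file that contains an opset-24 Attention node; onnx2tf.convert with input_onnx_file_path / output_folder_path is the library's documented entry point, while the file paths "attention.onnx" and "saved_model" are illustrative placeholders, not part of this diff.

# Minimal usage sketch (assumed paths, not part of the package diff)
import tensorflow as tf
import onnx2tf

# Convert the ONNX model to a TensorFlow SavedModel; the new Attention handler
# is picked up automatically for any Attention node in the graph.
onnx2tf.convert(
    input_onnx_file_path="attention.onnx",
    output_folder_path="saved_model",
)

# Reload the converted model and inspect its serving signature.
loaded = tf.saved_model.load("saved_model")
print(loaded.signatures["serving_default"].structured_input_signature)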