onnx2tf 1.29.5__py3-none-any.whl → 1.29.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,612 @@
+ import random
+ random.seed(0)
+ import numpy as np
+ np.random.seed(0)
+ import tensorflow as tf
+ import onnx_graphsurgeon as gs
+ from onnx2tf.utils.common_functions import (
+     replace_parameter,
+     get_constant_or_variable,
+     print_node_info,
+     inverted_operation_enable_disable,
+     make_tf_node_info,
+     get_replacement_parameter,
+     pre_process_transpose,
+     post_process_transpose,
+     transpose_with_flexing_deterrence,
+ )
+ from onnx2tf.utils.enums import (
+     ONNX_DTYPES_TO_TF_DTYPES,
+ )
+ from typing import Any, Optional
+
+
+ def _is_empty_input(input_var: Any) -> bool:
+     if input_var is None:
+         return True
+     if isinstance(input_var, str) and input_var == "":
+         return True
+     if hasattr(input_var, 'name') and input_var.name == "":
+         return True
+     return False
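A disconnected optional input can arrive here as None, as an empty string, or as a graph-surgeon Variable whose name is empty; `_is_empty_input` treats all three the same way. A minimal illustrative sketch, not part of the file:

    import onnx_graphsurgeon as gs
    _is_empty_input(None)                      # True
    _is_empty_input("")                        # True
    _is_empty_input(gs.Variable(name=""))      # True  -- empty-named placeholder
    _is_empty_input(gs.Variable(name="mask"))  # False -- a real optional input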
+
+
+ def _to_tf_tensor(input_var: Any, tf_layers_dict: dict) -> Any:
+     return tf_layers_dict[input_var.name]['tf_node'] \
+         if isinstance(input_var, gs.Variable) else input_var
+
+
+ def _pad_or_slice_last_dim(
+     mask_tensor: Any,
+     target_len: Any,
+     pad_value: Any,
+ ) -> Any:
+     mask_shape = tf.shape(mask_tensor)
+     last_dim = mask_shape[-1]
+
+     def _pad():
+         pad_len = target_len - last_dim
+         rank = tf.rank(mask_tensor)
+         paddings = tf.concat(
+             values=[
+                 tf.zeros(
+                     shape=tf.stack([rank - 1, 2]),
+                     dtype=tf.int32,
+                 ),
+                 tf.reshape(tf.stack([0, pad_len]), (1, 2)),
+             ],
+             axis=0,
+         )
+         return tf.pad(
+             tensor=mask_tensor,
+             paddings=paddings,
+             constant_values=pad_value,
+         )
+
+     def _slice():
+         return mask_tensor[..., :target_len]
+
+     def _identity():
+         return mask_tensor
+
+     return tf.cond(
+         pred=last_dim < target_len,
+         true_fn=_pad,
+         false_fn=lambda: tf.cond(
+             pred=last_dim > target_len,
+             true_fn=_slice,
+             false_fn=_identity,
+         ),
+     )
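`_pad_or_slice_last_dim` aligns a mask's last dimension with the runtime KV sequence length: it right-pads with `pad_value` when the mask is shorter, slices when it is longer, and passes it through otherwise. A minimal illustrative sketch, assuming eager TensorFlow (not part of the file):

    import tensorflow as tf
    mask = tf.constant([[0.0, 0.0, -1e9]])    # last dim = 3
    _pad_or_slice_last_dim(mask, 5, -1e9)      # shape (1, 5), right-padded with -1e9
    _pad_or_slice_last_dim(mask, 2, -1e9)      # shape (1, 2), sliced to the first 2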
+
+
+ @print_node_info
+ @inverted_operation_enable_disable
+ @get_replacement_parameter
+ def make_node(
+     *,
+     graph_node: gs.Node,
+     tf_layers_dict: dict,
+     **kwargs: dict,
+ ):
+     """Attention
+
+     Parameters
+     ----------
+     graph_node: gs.Node
+         graph_surgeon Node
+
+     tf_layers_dict: dict
+         optype, shape, dtype, tensorflow graph
+     """
+     before_op_output_shape_trans_q = \
+         tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+     before_op_output_shape_trans_k = \
+         tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
+     before_op_output_shape_trans_v = \
+         tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
+     before_op_output_shape_trans = \
+         before_op_output_shape_trans_q \
+         and before_op_output_shape_trans_k \
+         and before_op_output_shape_trans_v
+
+     graph_node_input_q = get_constant_or_variable(
+         graph_node.inputs[0],
+         before_op_output_shape_trans,
+     )
+     graph_node_input_k = get_constant_or_variable(
+         graph_node.inputs[1],
+         before_op_output_shape_trans,
+     )
+     graph_node_input_v = get_constant_or_variable(
+         graph_node.inputs[2],
+         before_op_output_shape_trans,
+     )
+
+     attn_mask_input = None
+     past_key_input = None
+     past_value_input = None
+     nonpad_kv_seqlen_input = None
+
+     if len(graph_node.inputs) >= 4 and not _is_empty_input(graph_node.inputs[3]):
+         attn_mask_input = graph_node.inputs[3]
+     if len(graph_node.inputs) >= 5 and not _is_empty_input(graph_node.inputs[4]):
+         past_key_input = graph_node.inputs[4]
+     if len(graph_node.inputs) >= 6 and not _is_empty_input(graph_node.inputs[5]):
+         past_value_input = graph_node.inputs[5]
+     if len(graph_node.inputs) >= 7 and not _is_empty_input(graph_node.inputs[6]):
+         nonpad_kv_seqlen_input = graph_node.inputs[6]
+
+     if (past_key_input is None) != (past_value_input is None):
+         past_key_input = None
+         past_value_input = None
+
+     Q = _to_tf_tensor(graph_node_input_q, tf_layers_dict)
+     K = _to_tf_tensor(graph_node_input_k, tf_layers_dict)
+     V = _to_tf_tensor(graph_node_input_v, tf_layers_dict)
+
+     # Pre-process transpose
+     Q = pre_process_transpose(
+         value_before_transpose=Q,
+         param_target='inputs',
+         param_name=graph_node.inputs[0].name,
+         **kwargs,
+     )
+     K = pre_process_transpose(
+         value_before_transpose=K,
+         param_target='inputs',
+         param_name=graph_node.inputs[1].name,
+         **kwargs,
+     )
+     V = pre_process_transpose(
+         value_before_transpose=V,
+         param_target='inputs',
+         param_name=graph_node.inputs[2].name,
+         **kwargs,
+     )
+
+     attn_mask = None
+     if attn_mask_input is not None:
+         graph_node_input_attn = get_constant_or_variable(
+             attn_mask_input,
+             before_op_output_shape_trans,
+         )
+         attn_mask = _to_tf_tensor(graph_node_input_attn, tf_layers_dict)
+         attn_mask = pre_process_transpose(
+             value_before_transpose=attn_mask,
+             param_target='inputs',
+             param_name=attn_mask_input.name,
+             **kwargs,
+         )
+
+     past_key = None
+     past_value = None
+     if past_key_input is not None and past_value_input is not None:
+         graph_node_input_past_key = get_constant_or_variable(
+             past_key_input,
+             before_op_output_shape_trans,
+         )
+         graph_node_input_past_value = get_constant_or_variable(
+             past_value_input,
+             before_op_output_shape_trans,
+         )
+         past_key = _to_tf_tensor(graph_node_input_past_key, tf_layers_dict)
+         past_value = _to_tf_tensor(graph_node_input_past_value, tf_layers_dict)
+         past_key = pre_process_transpose(
+             value_before_transpose=past_key,
+             param_target='inputs',
+             param_name=past_key_input.name,
+             **kwargs,
+         )
+         past_value = pre_process_transpose(
+             value_before_transpose=past_value,
+             param_target='inputs',
+             param_name=past_value_input.name,
+             **kwargs,
+         )
+
+     nonpad_kv_seqlen = None
+     if nonpad_kv_seqlen_input is not None:
+         graph_node_input_nonpad = get_constant_or_variable(
+             nonpad_kv_seqlen_input,
+             before_op_output_shape_trans,
+         )
+         nonpad_kv_seqlen = _to_tf_tensor(graph_node_input_nonpad, tf_layers_dict)
+         nonpad_kv_seqlen = pre_process_transpose(
+             value_before_transpose=nonpad_kv_seqlen,
+             param_target='inputs',
+             param_name=nonpad_kv_seqlen_input.name,
+             **kwargs,
+         )
+
+     graph_node_output_y: gs.Variable = graph_node.outputs[0]
+     y_shape = graph_node_output_y.shape
+     y_dtype = graph_node_output_y.dtype
+
+     graph_node_output_present_key: Optional[gs.Variable] = None
+     graph_node_output_present_value: Optional[gs.Variable] = None
+     graph_node_output_qk_matmul_output: Optional[gs.Variable] = None
+     if len(graph_node.outputs) >= 2 and not _is_empty_input(graph_node.outputs[1]):
+         graph_node_output_present_key = graph_node.outputs[1]
+     if len(graph_node.outputs) >= 3 and not _is_empty_input(graph_node.outputs[2]):
+         graph_node_output_present_value = graph_node.outputs[2]
+     if len(graph_node.outputs) >= 4 and not _is_empty_input(graph_node.outputs[3]):
+         graph_node_output_qk_matmul_output = graph_node.outputs[3]
+
+     if (graph_node_output_present_key is None) != (graph_node_output_present_value is None):
+         graph_node_output_present_key = None
+         graph_node_output_present_value = None
+
+     # Preserving Graph Structure (Dict)
+     tf_layers_dict[graph_node_output_y.name] = {
+         'optype': graph_node.op,
+         'shape': y_shape,
+         'dtype': y_dtype,
+     }
+     if graph_node_output_present_key is not None:
+         tf_layers_dict[graph_node_output_present_key.name] = {
+             'optype': graph_node.op,
+             'shape': graph_node_output_present_key.shape,
+             'dtype': graph_node_output_present_key.dtype,
+         }
+     if graph_node_output_present_value is not None:
+         tf_layers_dict[graph_node_output_present_value.name] = {
+             'optype': graph_node.op,
+             'shape': graph_node_output_present_value.shape,
+             'dtype': graph_node_output_present_value.dtype,
+         }
+     if graph_node_output_qk_matmul_output is not None:
+         tf_layers_dict[graph_node_output_qk_matmul_output.name] = {
+             'optype': graph_node.op,
+             'shape': graph_node_output_qk_matmul_output.shape,
+             'dtype': graph_node_output_qk_matmul_output.dtype,
+         }
+
+     # Attributes
+     opset = kwargs['opset']
+     is_causal = bool(graph_node.attrs.get('is_causal', 0))
+     qk_matmul_output_mode = int(graph_node.attrs.get('qk_matmul_output_mode', 0))
+     softcap = graph_node.attrs.get('softcap', 0.0)
+     scale = graph_node.attrs.get('scale', None)
+     q_num_heads = graph_node.attrs.get('q_num_heads', None)
+     kv_num_heads = graph_node.attrs.get('kv_num_heads', None)
+     softmax_precision = graph_node.attrs.get('softmax_precision', None)
+
+     is_causal = replace_parameter(
+         value_before_replacement=is_causal,
+         param_target='attributes',
+         param_name='is_causal',
+         **kwargs,
+     )
+     qk_matmul_output_mode = replace_parameter(
+         value_before_replacement=qk_matmul_output_mode,
+         param_target='attributes',
+         param_name='qk_matmul_output_mode',
+         **kwargs,
+     )
+     softcap = replace_parameter(
+         value_before_replacement=softcap,
+         param_target='attributes',
+         param_name='softcap',
+         **kwargs,
+     )
+     scale = replace_parameter(
+         value_before_replacement=scale,
+         param_target='attributes',
+         param_name='scale',
+         **kwargs,
+     )
+     q_num_heads = replace_parameter(
+         value_before_replacement=q_num_heads,
+         param_target='attributes',
+         param_name='q_num_heads',
+         **kwargs,
+     )
+     kv_num_heads = replace_parameter(
+         value_before_replacement=kv_num_heads,
+         param_target='attributes',
+         param_name='kv_num_heads',
+         **kwargs,
+     )
+     softmax_precision = replace_parameter(
+         value_before_replacement=softmax_precision,
+         param_target='attributes',
+         param_name='softmax_precision',
+         **kwargs,
+     )
+
+     # Reshape 3D inputs to 4D if needed
+     q_rank = len(Q.shape) if Q.shape is not None else None
+     k_rank = len(K.shape) if K.shape is not None else None
+     v_rank = len(V.shape) if V.shape is not None else None
+
+     input_q_is_3d = q_rank == 3
+     input_k_is_3d = k_rank == 3
+     input_v_is_3d = v_rank == 3
+     input_is_3d = input_q_is_3d and input_k_is_3d and input_v_is_3d
+
+     if input_is_3d:
+         if q_num_heads is None:
+             q_num_heads = 1
+         if kv_num_heads is None:
+             kv_num_heads = 1
+
+         q_shape = tf.shape(Q)
+         k_shape = tf.shape(K)
+         v_shape = tf.shape(V)
+
+         q_num_heads_tensor = \
+             tf.constant(q_num_heads, dtype=tf.int32) \
+             if isinstance(q_num_heads, int) else tf.cast(q_num_heads, tf.int32)
+         kv_num_heads_tensor = \
+             tf.constant(kv_num_heads, dtype=tf.int32) \
+             if isinstance(kv_num_heads, int) else tf.cast(kv_num_heads, tf.int32)
+
+         q_head_size = tf.math.floordiv(q_shape[2], q_num_heads_tensor)
+         k_head_size = tf.math.floordiv(k_shape[2], kv_num_heads_tensor)
+         v_head_size = tf.math.floordiv(v_shape[2], kv_num_heads_tensor)
+
+         Q = tf.reshape(
+             tensor=Q,
+             shape=tf.stack([q_shape[0], q_shape[1], q_num_heads_tensor, q_head_size]),
+         )
+         K = tf.reshape(
+             tensor=K,
+             shape=tf.stack([k_shape[0], k_shape[1], kv_num_heads_tensor, k_head_size]),
+         )
+         V = tf.reshape(
+             tensor=V,
+             shape=tf.stack([v_shape[0], v_shape[1], kv_num_heads_tensor, v_head_size]),
+         )
+         Q = transpose_with_flexing_deterrence(
+             input_tensor=Q,
+             perm=[0, 2, 1, 3],
+             **kwargs,
+         )
+         K = transpose_with_flexing_deterrence(
+             input_tensor=K,
+             perm=[0, 2, 1, 3],
+             **kwargs,
+         )
+         V = transpose_with_flexing_deterrence(
+             input_tensor=V,
+             perm=[0, 2, 1, 3],
+             **kwargs,
+         )
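For 3D inputs this block splits the packed hidden dimension into heads and moves the head axis forward, i.e. [batch, seq, heads * head_size] becomes [batch, heads, seq, head_size]. A small shape walk-through with illustrative values, using plain tf.transpose in place of transpose_with_flexing_deterrence (not part of the file):

    import tensorflow as tf
    q3 = tf.zeros([2, 7, 12])                  # [batch=2, q_seq=7, hidden=12], q_num_heads=4
    q4 = tf.reshape(q3, [2, 7, 4, 3])          # split hidden into 4 heads of size 3
    q4 = tf.transpose(q4, perm=[0, 2, 1, 3])   # -> [2, 4, 7, 3] = [batch, heads, seq, head_size]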
+
+     # Ensure dtype alignment for matmul
+     q_dtype = Q.dtype
+     if K.dtype != q_dtype:
+         K = tf.cast(K, q_dtype)
+     if V.dtype != q_dtype:
+         V = tf.cast(V, q_dtype)
+
+     if past_key is not None and past_value is not None:
+         if past_key.dtype != q_dtype:
+             past_key = tf.cast(past_key, q_dtype)
+         if past_value.dtype != q_dtype:
+             past_value = tf.cast(past_value, q_dtype)
+         K = tf.concat([past_key, K], axis=2)
+         V = tf.concat([past_value, V], axis=2)
+
+     present_key = K
+     present_value = V
+
+     # Heads for GQA/MQA
+     q_heads = Q.shape[1] if Q.shape[1] is not None else tf.shape(Q)[1]
+     kv_heads = present_key.shape[1] if present_key.shape[1] is not None else tf.shape(present_key)[1]
+
+     attn_key = present_key
+     attn_value = present_value
+     if isinstance(q_heads, int) and isinstance(kv_heads, int):
+         if kv_heads != q_heads:
+             repeat = q_heads // kv_heads
+             attn_key = tf.repeat(attn_key, repeats=repeat, axis=1)
+             attn_value = tf.repeat(attn_value, repeats=repeat, axis=1)
+     else:
+         repeat = tf.math.floordiv(tf.cast(q_heads, tf.int32), tf.cast(kv_heads, tf.int32))
+         attn_key = tf.repeat(attn_key, repeats=repeat, axis=1)
+         attn_value = tf.repeat(attn_value, repeats=repeat, axis=1)
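For grouped-query / multi-query attention the K and V tensors carry fewer heads than Q, so each KV head is repeated along axis 1 until the head counts match; e.g. q_heads=8 and kv_heads=2 give repeat=4. A minimal illustrative sketch (not part of the file):

    import tensorflow as tf
    k = tf.zeros([1, 2, 5, 64])            # [batch, kv_heads=2, kv_seq, head_size]
    k = tf.repeat(k, repeats=4, axis=1)    # -> [1, 8, 5, 64], each KV head duplicated 4x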
+
+     # Scale Q and K
+     head_size = tf.shape(Q)[-1]
+     if scale is None:
+         scale_value = tf.math.rsqrt(tf.cast(head_size, q_dtype))
+     else:
+         scale_value = tf.cast(scale, q_dtype)
+
+     scale_sqrt = tf.sqrt(scale_value)
+     Q = Q * scale_sqrt
+     K_scaled = attn_key * scale_sqrt
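Applying sqrt(scale) to both Q and K gives the same logits as scaling the product once, since (Q·sqrt(s))(K·sqrt(s))^T = s·QK^T, while keeping intermediate values smaller; with the default s = 1/sqrt(head_size) this is ordinary scaled dot-product attention. A quick numeric check (illustrative, not part of the file):

    import numpy as np
    q = np.random.randn(2, 3).astype(np.float32)
    k = np.random.randn(4, 3).astype(np.float32)
    s = 1.0 / np.sqrt(3.0)
    lhs = (q * np.sqrt(s)) @ (k * np.sqrt(s)).T
    rhs = s * (q @ k.T)
    assert np.allclose(lhs, rhs, atol=1e-6)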
+
+     # QK^T
+     qk_matmul_output = tf.matmul(
+         a=Q,
+         b=K_scaled,
+         transpose_b=True,
+     )
+
+     q_seq_len = tf.shape(Q)[2]
+     kv_seq_len = tf.shape(present_key)[2]
+
+     attn_bias = None
+
+     if is_causal:
+         causal_mask = tf.linalg.band_part(
+             tf.ones(shape=tf.stack([q_seq_len, kv_seq_len]), dtype=tf.float32),
+             -1,
+             0,
+         )
+         causal_mask = tf.cast(causal_mask, tf.bool)
+         neg_inf = tf.constant(-np.inf, dtype=q_dtype)
+         causal_bias = tf.where(
+             condition=causal_mask,
+             x=tf.zeros_like(causal_mask, dtype=q_dtype),
+             y=neg_inf,
+         )
+         attn_bias = causal_bias
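tf.linalg.band_part(ones, -1, 0) keeps only the lower triangle, so query position i can attend to key positions j <= i; positions outside the triangle get a -inf bias and vanish after softmax. For q_seq_len = kv_seq_len = 3 (illustrative, not part of the file):

    import tensorflow as tf
    tf.linalg.band_part(tf.ones([3, 3]), -1, 0)
    # [[1., 0., 0.],
    #  [1., 1., 0.],
    #  [1., 1., 1.]]   -> bias 0 where 1, -inf where 0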
+
+     if attn_mask is not None and attn_mask != "":
+         if attn_mask.dtype != tf.bool:
+             attn_mask = tf.cast(attn_mask, q_dtype)
+             pad_value = tf.constant(-np.inf, dtype=q_dtype)
+         else:
+             pad_value = False
+
+         if opset >= 24:
+             attn_mask = _pad_or_slice_last_dim(
+                 mask_tensor=attn_mask,
+                 target_len=kv_seq_len,
+                 pad_value=pad_value,
+             )
+
+         if attn_mask.dtype == tf.bool:
+             neg_inf = tf.constant(-np.inf, dtype=q_dtype)
+             mask_bias = tf.where(
+                 condition=attn_mask,
+                 x=tf.zeros_like(attn_mask, dtype=q_dtype),
+                 y=neg_inf,
+             )
+         else:
+             mask_bias = attn_mask
+
+         attn_bias = mask_bias if attn_bias is None else attn_bias + mask_bias
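Boolean masks become an additive bias (True -> 0, False -> -inf) while float masks are added to the logits unchanged; either form then broadcasts against the [batch, heads, q_seq, kv_seq] logits. A minimal sketch of the boolean path (illustrative, not part of the file):

    import numpy as np
    import tensorflow as tf
    m = tf.constant([[True, True, False]])
    tf.where(m, tf.zeros_like(m, dtype=tf.float32), tf.constant(-np.inf))
    # [[0., 0., -inf]]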
+
+     if opset >= 24 and nonpad_kv_seqlen is not None and past_key is None and past_value is None:
+         nonpad_kv_seqlen = tf.cast(nonpad_kv_seqlen, tf.int32)
+         seq_range = tf.range(kv_seq_len, dtype=tf.int32)
+         seq_range = tf.reshape(seq_range, shape=[1, -1])
+         nonpad_mask = seq_range < tf.reshape(nonpad_kv_seqlen, shape=[-1, 1])
+         nonpad_mask = tf.reshape(nonpad_mask, shape=[-1, 1, 1, kv_seq_len])
+         neg_inf = tf.constant(-np.inf, dtype=q_dtype)
+         nonpad_bias = tf.where(
+             condition=nonpad_mask,
+             x=tf.zeros_like(nonpad_mask, dtype=q_dtype),
+             y=neg_inf,
+         )
+         attn_bias = nonpad_bias if attn_bias is None else attn_bias + nonpad_bias
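nonpad_kv_seqlen holds the number of valid (non-padded) key positions per batch row; indices at or beyond that count are masked out with -inf. For kv_seq_len = 5 and nonpad_kv_seqlen = [3] the per-row validity mask is [True, True, True, False, False] (illustrative, not part of the file):

    import tensorflow as tf
    seq_range = tf.reshape(tf.range(5), [1, -1])        # [[0, 1, 2, 3, 4]]
    seq_range < tf.reshape(tf.constant([3]), [-1, 1])   # [[True, True, True, False, False]]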
+
+     if attn_bias is not None:
+         qk_with_bias = qk_matmul_output + attn_bias
+     else:
+         qk_with_bias = qk_matmul_output
+
+     # Softcap
+     qk_softcap = qk_with_bias
+     if isinstance(softcap, (float, int, np.floating, np.integer)):
+         if softcap > 0:
+             softcap_value = tf.cast(softcap, q_dtype)
+             qk_softcap = softcap_value * tf.math.tanh(qk_with_bias / softcap_value)
+     else:
+         softcap_value = tf.cast(softcap, q_dtype)
+         safe_softcap = tf.where(
+             condition=tf.equal(softcap_value, 0),
+             x=tf.ones_like(softcap_value),
+             y=softcap_value,
+         )
+         qk_softcap = tf.where(
+             condition=softcap_value > 0,
+             x=softcap_value * tf.math.tanh(qk_with_bias / safe_softcap),
+             y=qk_with_bias,
+         )
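Soft-capping squashes the logits into (-softcap, softcap) via softcap * tanh(x / softcap), limiting extreme scores before the softmax; when softcap <= 0 the logits pass through unchanged. A small numeric illustration (not part of the file):

    import numpy as np
    x = np.array([-100.0, 0.0, 5.0, 100.0], dtype=np.float32)
    softcap = 30.0
    softcap * np.tanh(x / softcap)    # ~ [-29.9, 0.0, 4.95, 29.9]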
+
+     # Softmax
+     softmax_dtype = None
+     if softmax_precision is not None:
+         if softmax_precision in ONNX_DTYPES_TO_TF_DTYPES:
+             softmax_dtype = ONNX_DTYPES_TO_TF_DTYPES[softmax_precision]
+         elif int(softmax_precision) == 16:
+             softmax_dtype = tf.bfloat16
+
+     qk_softmax_input = qk_softcap
+     if softmax_dtype is not None and softmax_dtype != qk_softmax_input.dtype:
+         qk_softmax_input = tf.cast(qk_softmax_input, softmax_dtype)
+
+     qk_softmax = tf.nn.softmax(
+         logits=qk_softmax_input,
+         axis=-1,
+     )
+
+     if softmax_dtype is not None and qk_softmax.dtype != q_dtype:
+         qk_softmax = tf.cast(qk_softmax, q_dtype)
+
+     # Output
+     Y = tf.matmul(
+         a=qk_softmax,
+         b=attn_value,
+     )
+
+     if input_is_3d:
+         Y = transpose_with_flexing_deterrence(
+             input_tensor=Y,
+             perm=[0, 2, 1, 3],
+             **kwargs,
+         )
+         y_shape_dyn = tf.shape(Y)
+         Y = tf.reshape(
+             tensor=Y,
+             shape=tf.stack([y_shape_dyn[0], y_shape_dyn[1], y_shape_dyn[2] * y_shape_dyn[3]]),
+         )
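For 3D inputs the [batch, heads, q_seq, v_head_size] result is transposed back to [batch, q_seq, heads, v_head_size] and the last two axes are merged, restoring the [batch, q_seq, hidden] layout of the original inputs. Continuing the earlier illustrative shapes (not part of the file):

    import tensorflow as tf
    y = tf.zeros([2, 4, 7, 3])                 # [batch, heads, q_seq, v_head_size]
    y = tf.transpose(y, perm=[0, 2, 1, 3])     # -> [2, 7, 4, 3]
    y = tf.reshape(y, [2, 7, 4 * 3])           # -> [2, 7, 12] = [batch, q_seq, hidden]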
+
+     tf_layers_dict[graph_node_output_y.name]['tf_node'] = Y
+
+     # Outputs for KV cache
+     if graph_node_output_present_key is not None:
+         tf_layers_dict[graph_node_output_present_key.name]['tf_node'] = present_key
+     if graph_node_output_present_value is not None:
+         tf_layers_dict[graph_node_output_present_value.name]['tf_node'] = present_value
+
+     # qk_matmul_output output mode
+     if graph_node_output_qk_matmul_output is not None:
+         qk_output = qk_matmul_output
+         if qk_matmul_output_mode == 1:
+             qk_output = qk_with_bias
+         elif qk_matmul_output_mode == 2:
+             qk_output = qk_softcap
+         elif qk_matmul_output_mode == 3:
+             qk_output = qk_softmax
+         tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'] = qk_output
+
+     # Post-process transpose
+     tf_layers_dict[graph_node_output_y.name]['tf_node'] = post_process_transpose(
+         value_before_transpose=tf_layers_dict[graph_node_output_y.name]['tf_node'],
+         param_target='outputs',
+         param_name=graph_node.outputs[0].name,
+         **kwargs,
+     )
+     if graph_node_output_present_key is not None:
+         tf_layers_dict[graph_node_output_present_key.name]['tf_node'] = post_process_transpose(
+             value_before_transpose=tf_layers_dict[graph_node_output_present_key.name]['tf_node'],
+             param_target='outputs',
+             param_name=graph_node_output_present_key.name,
+             **kwargs,
+         )
+     if graph_node_output_present_value is not None:
+         tf_layers_dict[graph_node_output_present_value.name]['tf_node'] = post_process_transpose(
+             value_before_transpose=tf_layers_dict[graph_node_output_present_value.name]['tf_node'],
+             param_target='outputs',
+             param_name=graph_node_output_present_value.name,
+             **kwargs,
+         )
+     if graph_node_output_qk_matmul_output is not None:
+         tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'] = post_process_transpose(
+             value_before_transpose=tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'],
+             param_target='outputs',
+             param_name=graph_node_output_qk_matmul_output.name,
+             **kwargs,
+         )
+
+     # Generation of Debug Info
+     tf_layers_dict[graph_node_output_y.name]['tf_node_info'] = \
+         make_tf_node_info(
+             node_info={
+                 'tf_op_type': 'Attention',
+                 'tf_inputs': {
+                     'a': qk_softmax,
+                     'b': attn_value,
+                 },
+                 'tf_outputs': {
+                     'output': tf_layers_dict[graph_node_output_y.name]['tf_node'],
+                 },
+             }
+         )
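Overall, the handler lowers ONNX Attention to plain TensorFlow ops: scaled QK^T, optional causal / user / padding masks folded into one additive bias, optional soft-capping, softmax (optionally in a different precision), and a final matmul with V, plus the KV-cache and QK side outputs. A rough, self-contained sanity check of the core math on random 4D tensors (an illustrative sketch only, not part of the package):

    import numpy as np
    import tensorflow as tf

    B, H, S, D = 1, 2, 4, 8
    q = tf.random.normal([B, H, S, D])
    k = tf.random.normal([B, H, S, D])
    v = tf.random.normal([B, H, S, D])

    scale = 1.0 / np.sqrt(D)
    s = np.sqrt(scale)
    logits = tf.matmul(q * s, k * s, transpose_b=True)   # same as scale * (q @ k^T)
    out = tf.matmul(tf.nn.softmax(logits, axis=-1), v)   # [B, H, S, D]

    ref = tf.matmul(tf.nn.softmax(scale * tf.matmul(q, k, transpose_b=True), axis=-1), v)
    assert np.allclose(out.numpy(), ref.numpy(), atol=1e-5)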