onnx2tf 1.29.5__py3-none-any.whl → 1.29.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/ops/AffineGrid.py +187 -0
- onnx2tf/ops/Attention.py +612 -0
- onnx2tf/ops/BlackmanWindow.py +115 -0
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/METADATA +4 -4
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/RECORD +10 -7
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/WHEEL +1 -1
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/licenses/LICENSE +0 -0
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/licenses/LICENSE_onnx-tensorflow +0 -0
- {onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/top_level.txt +0 -0
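
The three new files under onnx2tf/ops/ take effect without any registration step: the converter resolves each ONNX op type to a module of the same name under onnx2tf/ops/ at conversion time and calls its make_node(). A minimal sketch of that lookup, assuming the importlib-based dispatch used in onnx2tf.py (simplified; the real resolver also handles custom and domain-prefixed ops):

import importlib

def resolve_op_handler(op_type: str):
    # onnx2tf maps an ONNX op type such as 'AffineGrid' to the module
    # onnx2tf/ops/AffineGrid.py, so shipping the file enables the op.
    module = importlib.import_module(f'onnx2tf.ops.{op_type}')
    return module.make_node

make_node = resolve_op_handler('AffineGrid')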
onnx2tf/__init__.py
CHANGED
-__version__ = '1.29.5'
+__version__ = '1.29.6'

onnx2tf/ops/AffineGrid.py
ADDED
@@ -0,0 +1,187 @@
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
    replace_parameter,
    get_constant_or_variable,
    print_node_info,
    inverted_operation_enable_disable,
    make_tf_node_info,
    get_replacement_parameter,
    pre_process_transpose,
    post_process_transpose,
)
from typing import Any


def _make_coords(
    size_dim: Any,
    align_corners: bool,
    dtype: Any,
) -> Any:
    size_dim = tf.cast(size_dim, tf.int32)
    size_f = tf.cast(size_dim, dtype)

    if align_corners:
        denom = size_f - tf.constant(1.0, dtype=dtype)
        step = tf.where(
            condition=size_dim > 1,
            x=tf.constant(2.0, dtype=dtype) / denom,
            y=tf.constant(0.0, dtype=dtype),
        )
        start = tf.constant(-1.0, dtype=dtype)
    else:
        step = tf.constant(2.0, dtype=dtype) / size_f
        start = tf.constant(-1.0, dtype=dtype) + step / tf.constant(2.0, dtype=dtype)

    return start + tf.range(size_dim, dtype=dtype) * step


@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """AffineGrid

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2

    graph_node_input_theta = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_size = get_constant_or_variable(
        graph_node.inputs[1],
        False \
        if hasattr(graph_node.inputs[1], 'values') \
        and isinstance(graph_node.inputs[1].values, np.ndarray) \
        else before_op_output_shape_trans,
    )

    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
        'nhwc': True,
    }

    theta = tf_layers_dict[graph_node_input_theta.name]['tf_node'] \
        if isinstance(graph_node_input_theta, gs.Variable) else graph_node_input_theta
    size = tf_layers_dict[graph_node_input_size.name]['tf_node'] \
        if isinstance(graph_node_input_size, gs.Variable) else graph_node_input_size

    # Pre-process transpose
    theta = pre_process_transpose(
        value_before_transpose=theta,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )

    align_corners = bool(graph_node.attrs.get('align_corners', 0))
    align_corners = replace_parameter(
        value_before_replacement=align_corners,
        param_target='attributes',
        param_name='align_corners',
        **kwargs,
    )

    theta_dtype = theta.dtype
    size_tensor = tf.cast(size, tf.int32)

    size_rank = size_tensor.shape[0] if size_tensor.shape.rank == 1 else None

    def _build_grid_2d(size_vec):
        N, _, H, W = tf.unstack(size_vec)
        h_coords = _make_coords(H, align_corners, theta_dtype)
        w_coords = _make_coords(W, align_corners, theta_dtype)
        grid_h, grid_w = tf.meshgrid(h_coords, w_coords, indexing='ij')
        ones = tf.ones_like(grid_w, dtype=theta_dtype)
        grid = tf.stack([grid_w, grid_h, ones], axis=-1)
        grid_flat = tf.reshape(grid, shape=[-1, 3])
        grid_flat_t = tf.transpose(grid_flat)
        grid_flat_t = tf.cast(grid_flat_t, theta_dtype)
        out = tf.matmul(theta, grid_flat_t)
        out = tf.transpose(out, perm=[0, 2, 1])
        out = tf.reshape(out, shape=tf.stack([N, H, W, 2]))
        return out

    def _build_grid_3d(size_vec):
        N, _, D, H, W = tf.unstack(size_vec)
        d_coords = _make_coords(D, align_corners, theta_dtype)
        h_coords = _make_coords(H, align_corners, theta_dtype)
        w_coords = _make_coords(W, align_corners, theta_dtype)
        grid_d, grid_h, grid_w = tf.meshgrid(d_coords, h_coords, w_coords, indexing='ij')
        ones = tf.ones_like(grid_w, dtype=theta_dtype)
        grid = tf.stack([grid_w, grid_h, grid_d, ones], axis=-1)
        grid_flat = tf.reshape(grid, shape=[-1, 4])
        grid_flat_t = tf.transpose(grid_flat)
        grid_flat_t = tf.cast(grid_flat_t, theta_dtype)
        out = tf.matmul(theta, grid_flat_t)
        out = tf.transpose(out, perm=[0, 2, 1])
        out = tf.reshape(out, shape=tf.stack([N, D, H, W, 3]))
        return out

    if size_rank == 4:
        grid = _build_grid_2d(size_tensor)
    elif size_rank == 5:
        grid = _build_grid_3d(size_tensor)
    else:
        size_dim = tf.shape(size_tensor)[0]
        grid = tf.cond(
            pred=tf.equal(size_dim, 4),
            true_fn=lambda: _build_grid_2d(size_tensor),
            false_fn=lambda: _build_grid_3d(size_tensor),
        )

    tf_layers_dict[graph_node_output.name]['tf_node'] = grid

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'AffineGrid',
                'tf_inputs': {
                    'theta': theta,
                    'size': size,
                    'align_corners': align_corners,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
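
The grid math above mirrors ONNX AffineGrid semantics: per-axis coordinates normalized to [-1, 1] (endpoints included only when align_corners=1, pixel centers otherwise), stacked as homogeneous (x, y, 1) vectors and multiplied by theta, yielding an (N, H, W, 2) grid. A NumPy reference sketch of the same 2D computation, useful for checking the converted op (these helper names are illustrative, not onnx2tf APIs):

import numpy as np

def make_coords(size, align_corners):
    # Normalized coordinates in [-1, 1]; endpoints are included only
    # when align_corners is True, otherwise pixel centers are sampled.
    if align_corners:
        step = 2.0 / (size - 1) if size > 1 else 0.0
        return -1.0 + np.arange(size) * step
    step = 2.0 / size
    return (-1.0 + step / 2.0) + np.arange(size) * step

def affine_grid_2d(theta, N, H, W, align_corners=False):
    # theta: (N, 2, 3) affine matrices; output: (N, H, W, 2), (x, y) last.
    ys = make_coords(H, align_corners)
    xs = make_coords(W, align_corners)
    gy, gx = np.meshgrid(ys, xs, indexing='ij')
    base = np.stack([gx, gy, np.ones_like(gx)], axis=-1).reshape(-1, 3)  # (H*W, 3)
    out = np.einsum('nij,pj->npi', theta, base)  # apply theta to every grid point
    return out.reshape(N, H, W, 2)

# Identity theta reproduces the bare coordinate grid.
theta = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = affine_grid_2d(theta, 1, 4, 4)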
onnx2tf/ops/Attention.py
ADDED
@@ -0,0 +1,612 @@
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
    replace_parameter,
    get_constant_or_variable,
    print_node_info,
    inverted_operation_enable_disable,
    make_tf_node_info,
    get_replacement_parameter,
    pre_process_transpose,
    post_process_transpose,
    transpose_with_flexing_deterrence,
)
from onnx2tf.utils.enums import (
    ONNX_DTYPES_TO_TF_DTYPES,
)
from typing import Any, Optional


def _is_empty_input(input_var: Any) -> bool:
    if input_var is None:
        return True
    if isinstance(input_var, str) and input_var == "":
        return True
    if hasattr(input_var, 'name') and input_var.name == "":
        return True
    return False


def _to_tf_tensor(input_var: Any, tf_layers_dict: dict) -> Any:
    return tf_layers_dict[input_var.name]['tf_node'] \
        if isinstance(input_var, gs.Variable) else input_var


def _pad_or_slice_last_dim(
    mask_tensor: Any,
    target_len: Any,
    pad_value: Any,
) -> Any:
    mask_shape = tf.shape(mask_tensor)
    last_dim = mask_shape[-1]

    def _pad():
        pad_len = target_len - last_dim
        rank = tf.rank(mask_tensor)
        paddings = tf.concat(
            values=[
                tf.zeros(
                    shape=tf.stack([rank - 1, 2]),
                    dtype=tf.int32,
                ),
                tf.reshape(tf.stack([0, pad_len]), (1, 2)),
            ],
            axis=0,
        )
        return tf.pad(
            tensor=mask_tensor,
            paddings=paddings,
            constant_values=pad_value,
        )

    def _slice():
        return mask_tensor[..., :target_len]

    def _identity():
        return mask_tensor

    return tf.cond(
        pred=last_dim < target_len,
        true_fn=_pad,
        false_fn=lambda: tf.cond(
            pred=last_dim > target_len,
            true_fn=_slice,
            false_fn=_identity,
        ),
    )


@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """Attention

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_q = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_k = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_v = \
        tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_q \
        and before_op_output_shape_trans_k \
        and before_op_output_shape_trans_v

    graph_node_input_q = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_k = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_input_v = get_constant_or_variable(
        graph_node.inputs[2],
        before_op_output_shape_trans,
    )

    attn_mask_input = None
    past_key_input = None
    past_value_input = None
    nonpad_kv_seqlen_input = None

    if len(graph_node.inputs) >= 4 and not _is_empty_input(graph_node.inputs[3]):
        attn_mask_input = graph_node.inputs[3]
    if len(graph_node.inputs) >= 5 and not _is_empty_input(graph_node.inputs[4]):
        past_key_input = graph_node.inputs[4]
    if len(graph_node.inputs) >= 6 and not _is_empty_input(graph_node.inputs[5]):
        past_value_input = graph_node.inputs[5]
    if len(graph_node.inputs) >= 7 and not _is_empty_input(graph_node.inputs[6]):
        nonpad_kv_seqlen_input = graph_node.inputs[6]

    if (past_key_input is None) != (past_value_input is None):
        past_key_input = None
        past_value_input = None

    Q = _to_tf_tensor(graph_node_input_q, tf_layers_dict)
    K = _to_tf_tensor(graph_node_input_k, tf_layers_dict)
    V = _to_tf_tensor(graph_node_input_v, tf_layers_dict)

    # Pre-process transpose
    Q = pre_process_transpose(
        value_before_transpose=Q,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )
    K = pre_process_transpose(
        value_before_transpose=K,
        param_target='inputs',
        param_name=graph_node.inputs[1].name,
        **kwargs,
    )
    V = pre_process_transpose(
        value_before_transpose=V,
        param_target='inputs',
        param_name=graph_node.inputs[2].name,
        **kwargs,
    )

    attn_mask = None
    if attn_mask_input is not None:
        graph_node_input_attn = get_constant_or_variable(
            attn_mask_input,
            before_op_output_shape_trans,
        )
        attn_mask = _to_tf_tensor(graph_node_input_attn, tf_layers_dict)
        attn_mask = pre_process_transpose(
            value_before_transpose=attn_mask,
            param_target='inputs',
            param_name=attn_mask_input.name,
            **kwargs,
        )

    past_key = None
    past_value = None
    if past_key_input is not None and past_value_input is not None:
        graph_node_input_past_key = get_constant_or_variable(
            past_key_input,
            before_op_output_shape_trans,
        )
        graph_node_input_past_value = get_constant_or_variable(
            past_value_input,
            before_op_output_shape_trans,
        )
        past_key = _to_tf_tensor(graph_node_input_past_key, tf_layers_dict)
        past_value = _to_tf_tensor(graph_node_input_past_value, tf_layers_dict)
        past_key = pre_process_transpose(
            value_before_transpose=past_key,
            param_target='inputs',
            param_name=past_key_input.name,
            **kwargs,
        )
        past_value = pre_process_transpose(
            value_before_transpose=past_value,
            param_target='inputs',
            param_name=past_value_input.name,
            **kwargs,
        )

    nonpad_kv_seqlen = None
    if nonpad_kv_seqlen_input is not None:
        graph_node_input_nonpad = get_constant_or_variable(
            nonpad_kv_seqlen_input,
            before_op_output_shape_trans,
        )
        nonpad_kv_seqlen = _to_tf_tensor(graph_node_input_nonpad, tf_layers_dict)
        nonpad_kv_seqlen = pre_process_transpose(
            value_before_transpose=nonpad_kv_seqlen,
            param_target='inputs',
            param_name=nonpad_kv_seqlen_input.name,
            **kwargs,
        )

    graph_node_output_y: gs.Variable = graph_node.outputs[0]
    y_shape = graph_node_output_y.shape
    y_dtype = graph_node_output_y.dtype

    graph_node_output_present_key: Optional[gs.Variable] = None
    graph_node_output_present_value: Optional[gs.Variable] = None
    graph_node_output_qk_matmul_output: Optional[gs.Variable] = None
    if len(graph_node.outputs) >= 2 and not _is_empty_input(graph_node.outputs[1]):
        graph_node_output_present_key = graph_node.outputs[1]
    if len(graph_node.outputs) >= 3 and not _is_empty_input(graph_node.outputs[2]):
        graph_node_output_present_value = graph_node.outputs[2]
    if len(graph_node.outputs) >= 4 and not _is_empty_input(graph_node.outputs[3]):
        graph_node_output_qk_matmul_output = graph_node.outputs[3]

    if (graph_node_output_present_key is None) != (graph_node_output_present_value is None):
        graph_node_output_present_key = None
        graph_node_output_present_value = None

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output_y.name] = {
        'optype': graph_node.op,
        'shape': y_shape,
        'dtype': y_dtype,
    }
    if graph_node_output_present_key is not None:
        tf_layers_dict[graph_node_output_present_key.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output_present_key.shape,
            'dtype': graph_node_output_present_key.dtype,
        }
    if graph_node_output_present_value is not None:
        tf_layers_dict[graph_node_output_present_value.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output_present_value.shape,
            'dtype': graph_node_output_present_value.dtype,
        }
    if graph_node_output_qk_matmul_output is not None:
        tf_layers_dict[graph_node_output_qk_matmul_output.name] = {
            'optype': graph_node.op,
            'shape': graph_node_output_qk_matmul_output.shape,
            'dtype': graph_node_output_qk_matmul_output.dtype,
        }

    # Attributes
    opset = kwargs['opset']
    is_causal = bool(graph_node.attrs.get('is_causal', 0))
    qk_matmul_output_mode = int(graph_node.attrs.get('qk_matmul_output_mode', 0))
    softcap = graph_node.attrs.get('softcap', 0.0)
    scale = graph_node.attrs.get('scale', None)
    q_num_heads = graph_node.attrs.get('q_num_heads', None)
    kv_num_heads = graph_node.attrs.get('kv_num_heads', None)
    softmax_precision = graph_node.attrs.get('softmax_precision', None)

    is_causal = replace_parameter(
        value_before_replacement=is_causal,
        param_target='attributes',
        param_name='is_causal',
        **kwargs,
    )
    qk_matmul_output_mode = replace_parameter(
        value_before_replacement=qk_matmul_output_mode,
        param_target='attributes',
        param_name='qk_matmul_output_mode',
        **kwargs,
    )
    softcap = replace_parameter(
        value_before_replacement=softcap,
        param_target='attributes',
        param_name='softcap',
        **kwargs,
    )
    scale = replace_parameter(
        value_before_replacement=scale,
        param_target='attributes',
        param_name='scale',
        **kwargs,
    )
    q_num_heads = replace_parameter(
        value_before_replacement=q_num_heads,
        param_target='attributes',
        param_name='q_num_heads',
        **kwargs,
    )
    kv_num_heads = replace_parameter(
        value_before_replacement=kv_num_heads,
        param_target='attributes',
        param_name='kv_num_heads',
        **kwargs,
    )
    softmax_precision = replace_parameter(
        value_before_replacement=softmax_precision,
        param_target='attributes',
        param_name='softmax_precision',
        **kwargs,
    )

    # Reshape 3D inputs to 4D if needed
    q_rank = len(Q.shape) if Q.shape is not None else None
    k_rank = len(K.shape) if K.shape is not None else None
    v_rank = len(V.shape) if V.shape is not None else None

    input_q_is_3d = q_rank == 3
    input_k_is_3d = k_rank == 3
    input_v_is_3d = v_rank == 3
    input_is_3d = input_q_is_3d and input_k_is_3d and input_v_is_3d

    if input_is_3d:
        if q_num_heads is None:
            q_num_heads = 1
        if kv_num_heads is None:
            kv_num_heads = 1

        q_shape = tf.shape(Q)
        k_shape = tf.shape(K)
        v_shape = tf.shape(V)

        q_num_heads_tensor = \
            tf.constant(q_num_heads, dtype=tf.int32) \
            if isinstance(q_num_heads, int) else tf.cast(q_num_heads, tf.int32)
        kv_num_heads_tensor = \
            tf.constant(kv_num_heads, dtype=tf.int32) \
            if isinstance(kv_num_heads, int) else tf.cast(kv_num_heads, tf.int32)

        q_head_size = tf.math.floordiv(q_shape[2], q_num_heads_tensor)
        k_head_size = tf.math.floordiv(k_shape[2], kv_num_heads_tensor)
        v_head_size = tf.math.floordiv(v_shape[2], kv_num_heads_tensor)

        Q = tf.reshape(
            tensor=Q,
            shape=tf.stack([q_shape[0], q_shape[1], q_num_heads_tensor, q_head_size]),
        )
        K = tf.reshape(
            tensor=K,
            shape=tf.stack([k_shape[0], k_shape[1], kv_num_heads_tensor, k_head_size]),
        )
        V = tf.reshape(
            tensor=V,
            shape=tf.stack([v_shape[0], v_shape[1], kv_num_heads_tensor, v_head_size]),
        )
        Q = transpose_with_flexing_deterrence(
            input_tensor=Q,
            perm=[0, 2, 1, 3],
            **kwargs,
        )
        K = transpose_with_flexing_deterrence(
            input_tensor=K,
            perm=[0, 2, 1, 3],
            **kwargs,
        )
        V = transpose_with_flexing_deterrence(
            input_tensor=V,
            perm=[0, 2, 1, 3],
            **kwargs,
        )

    # Ensure dtype alignment for matmul
    q_dtype = Q.dtype
    if K.dtype != q_dtype:
        K = tf.cast(K, q_dtype)
    if V.dtype != q_dtype:
        V = tf.cast(V, q_dtype)

    if past_key is not None and past_value is not None:
        if past_key.dtype != q_dtype:
            past_key = tf.cast(past_key, q_dtype)
        if past_value.dtype != q_dtype:
            past_value = tf.cast(past_value, q_dtype)
        K = tf.concat([past_key, K], axis=2)
        V = tf.concat([past_value, V], axis=2)

    present_key = K
    present_value = V

    # Heads for GQA/MQA
    q_heads = Q.shape[1] if Q.shape[1] is not None else tf.shape(Q)[1]
    kv_heads = present_key.shape[1] if present_key.shape[1] is not None else tf.shape(present_key)[1]

    attn_key = present_key
    attn_value = present_value
    if isinstance(q_heads, int) and isinstance(kv_heads, int):
        if kv_heads != q_heads:
            repeat = q_heads // kv_heads
            attn_key = tf.repeat(attn_key, repeats=repeat, axis=1)
            attn_value = tf.repeat(attn_value, repeats=repeat, axis=1)
    else:
        repeat = tf.math.floordiv(tf.cast(q_heads, tf.int32), tf.cast(kv_heads, tf.int32))
        attn_key = tf.repeat(attn_key, repeats=repeat, axis=1)
        attn_value = tf.repeat(attn_value, repeats=repeat, axis=1)

    # Scale Q and K
    head_size = tf.shape(Q)[-1]
    if scale is None:
        scale_value = tf.math.rsqrt(tf.cast(head_size, q_dtype))
    else:
        scale_value = tf.cast(scale, q_dtype)

    scale_sqrt = tf.sqrt(scale_value)
    Q = Q * scale_sqrt
    K_scaled = attn_key * scale_sqrt

    # QK^T
    qk_matmul_output = tf.matmul(
        a=Q,
        b=K_scaled,
        transpose_b=True,
    )

    q_seq_len = tf.shape(Q)[2]
    kv_seq_len = tf.shape(present_key)[2]

    attn_bias = None

    if is_causal:
        causal_mask = tf.linalg.band_part(
            tf.ones(shape=tf.stack([q_seq_len, kv_seq_len]), dtype=tf.float32),
            -1,
            0,
        )
        causal_mask = tf.cast(causal_mask, tf.bool)
        neg_inf = tf.constant(-np.inf, dtype=q_dtype)
        causal_bias = tf.where(
            condition=causal_mask,
            x=tf.zeros_like(causal_mask, dtype=q_dtype),
            y=neg_inf,
        )
        attn_bias = causal_bias

    if attn_mask is not None and attn_mask != "":
        if attn_mask.dtype != tf.bool:
            attn_mask = tf.cast(attn_mask, q_dtype)
            pad_value = tf.constant(-np.inf, dtype=q_dtype)
        else:
            pad_value = False

        if opset >= 24:
            attn_mask = _pad_or_slice_last_dim(
                mask_tensor=attn_mask,
                target_len=kv_seq_len,
                pad_value=pad_value,
            )

        if attn_mask.dtype == tf.bool:
            neg_inf = tf.constant(-np.inf, dtype=q_dtype)
            mask_bias = tf.where(
                condition=attn_mask,
                x=tf.zeros_like(attn_mask, dtype=q_dtype),
                y=neg_inf,
            )
        else:
            mask_bias = attn_mask

        attn_bias = mask_bias if attn_bias is None else attn_bias + mask_bias

    if opset >= 24 and nonpad_kv_seqlen is not None and past_key is None and past_value is None:
        nonpad_kv_seqlen = tf.cast(nonpad_kv_seqlen, tf.int32)
        seq_range = tf.range(kv_seq_len, dtype=tf.int32)
        seq_range = tf.reshape(seq_range, shape=[1, -1])
        nonpad_mask = seq_range < tf.reshape(nonpad_kv_seqlen, shape=[-1, 1])
        nonpad_mask = tf.reshape(nonpad_mask, shape=[-1, 1, 1, kv_seq_len])
        neg_inf = tf.constant(-np.inf, dtype=q_dtype)
        nonpad_bias = tf.where(
            condition=nonpad_mask,
            x=tf.zeros_like(nonpad_mask, dtype=q_dtype),
            y=neg_inf,
        )
        attn_bias = nonpad_bias if attn_bias is None else attn_bias + nonpad_bias

    if attn_bias is not None:
        qk_with_bias = qk_matmul_output + attn_bias
    else:
        qk_with_bias = qk_matmul_output

    # Softcap
    qk_softcap = qk_with_bias
    if isinstance(softcap, (float, int, np.floating, np.integer)):
        if softcap > 0:
            softcap_value = tf.cast(softcap, q_dtype)
            qk_softcap = softcap_value * tf.math.tanh(qk_with_bias / softcap_value)
    else:
        softcap_value = tf.cast(softcap, q_dtype)
        safe_softcap = tf.where(
            condition=tf.equal(softcap_value, 0),
            x=tf.ones_like(softcap_value),
            y=softcap_value,
        )
        qk_softcap = tf.where(
            condition=softcap_value > 0,
            x=softcap_value * tf.math.tanh(qk_with_bias / safe_softcap),
            y=qk_with_bias,
        )

    # Softmax
    softmax_dtype = None
    if softmax_precision is not None:
        if softmax_precision in ONNX_DTYPES_TO_TF_DTYPES:
            softmax_dtype = ONNX_DTYPES_TO_TF_DTYPES[softmax_precision]
        elif int(softmax_precision) == 16:
            softmax_dtype = tf.bfloat16

    qk_softmax_input = qk_softcap
    if softmax_dtype is not None and softmax_dtype != qk_softmax_input.dtype:
        qk_softmax_input = tf.cast(qk_softmax_input, softmax_dtype)

    qk_softmax = tf.nn.softmax(
        logits=qk_softmax_input,
        axis=-1,
    )

    if softmax_dtype is not None and qk_softmax.dtype != q_dtype:
        qk_softmax = tf.cast(qk_softmax, q_dtype)

    # Output
    Y = tf.matmul(
        a=qk_softmax,
        b=attn_value,
    )

    if input_is_3d:
        Y = transpose_with_flexing_deterrence(
            input_tensor=Y,
            perm=[0, 2, 1, 3],
            **kwargs,
        )
        y_shape_dyn = tf.shape(Y)
        Y = tf.reshape(
            tensor=Y,
            shape=tf.stack([y_shape_dyn[0], y_shape_dyn[1], y_shape_dyn[2] * y_shape_dyn[3]]),
        )

    tf_layers_dict[graph_node_output_y.name]['tf_node'] = Y

    # Outputs for KV cache
    if graph_node_output_present_key is not None:
        tf_layers_dict[graph_node_output_present_key.name]['tf_node'] = present_key
    if graph_node_output_present_value is not None:
        tf_layers_dict[graph_node_output_present_value.name]['tf_node'] = present_value

    # qk_matmul_output output mode
    if graph_node_output_qk_matmul_output is not None:
        qk_output = qk_matmul_output
        if qk_matmul_output_mode == 1:
            qk_output = qk_with_bias
        elif qk_matmul_output_mode == 2:
            qk_output = qk_softcap
        elif qk_matmul_output_mode == 3:
            qk_output = qk_softmax
        tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'] = qk_output

    # Post-process transpose
    tf_layers_dict[graph_node_output_y.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output_y.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )
    if graph_node_output_present_key is not None:
        tf_layers_dict[graph_node_output_present_key.name]['tf_node'] = post_process_transpose(
            value_before_transpose=tf_layers_dict[graph_node_output_present_key.name]['tf_node'],
            param_target='outputs',
            param_name=graph_node_output_present_key.name,
            **kwargs,
        )
    if graph_node_output_present_value is not None:
        tf_layers_dict[graph_node_output_present_value.name]['tf_node'] = post_process_transpose(
            value_before_transpose=tf_layers_dict[graph_node_output_present_value.name]['tf_node'],
            param_target='outputs',
            param_name=graph_node_output_present_value.name,
            **kwargs,
        )
    if graph_node_output_qk_matmul_output is not None:
        tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'] = post_process_transpose(
            value_before_transpose=tf_layers_dict[graph_node_output_qk_matmul_output.name]['tf_node'],
            param_target='outputs',
            param_name=graph_node_output_qk_matmul_output.name,
            **kwargs,
        )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output_y.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'Attention',
                'tf_inputs': {
                    'a': qk_softmax,
                    'b': attn_value,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output_y.name]['tf_node'],
                },
            }
        )
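
Two details of the implementation above are worth calling out. First, GQA/MQA is handled by repeating each KV head q_heads // kv_heads times along the head axis. Second, rather than scaling the QK^T logits once, both Q and K are multiplied by sqrt(scale), which is mathematically identical ((sqrt(s)Q)(sqrt(s)K)^T = s * QK^T) but keeps intermediate values in a narrower range. A small NumPy sketch verifying both properties (shapes are illustrative):

import numpy as np

rng = np.random.default_rng(0)
B, Hq, Hkv, Lq, Lkv, D = 1, 8, 2, 5, 7, 16
Q = rng.standard_normal((B, Hq, Lq, D)).astype(np.float32)
K = rng.standard_normal((B, Hkv, Lkv, D)).astype(np.float32)

scale = 1.0 / np.sqrt(D)
s = np.sqrt(scale)

# GQA/MQA: repeat each KV head Hq // Hkv times along the head axis.
K_rep = np.repeat(K, Hq // Hkv, axis=1)

# Splitting the scale across Q and K is equivalent to scaling the
# logits once: (s*Q) @ (s*K)^T == scale * (Q @ K^T).
logits_split = (s * Q) @ (s * K_rep).transpose(0, 1, 3, 2)
logits_once = scale * (Q @ K_rep.transpose(0, 1, 3, 2))
assert np.allclose(logits_split, logits_once, atol=1e-5)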
onnx2tf/ops/BlackmanWindow.py
ADDED
@@ -0,0 +1,115 @@
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
from onnx import TensorProto
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
    get_constant_or_variable,
    print_node_info,
    inverted_operation_enable_disable,
    make_tf_node_info,
    get_replacement_parameter,
    pre_process_transpose,
    post_process_transpose,
)
from onnx2tf.utils.enums import ONNX_DTYPES_TO_TF_DTYPES


@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """BlackmanWindow

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans=False,
    )
    size = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1

    graph_node_output: gs.Variable = graph_node.outputs[0]

    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    output_datatype = int(graph_node.attrs.get('output_datatype', TensorProto.FLOAT))
    output_datatype = ONNX_DTYPES_TO_TF_DTYPES[output_datatype]
    periodic = bool(graph_node.attrs.get('periodic', 1))

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
    }

    # Pre-process transpose
    size = pre_process_transpose(
        value_before_transpose=size,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )

    # Generation of TF OP
    size_fp = tf.cast(size, tf.float32)
    periodic_size_fp = size_fp
    symmetric_size_fp = size_fp - tf.constant(1.0, dtype=tf.float32)
    is_periodic_fp = tf.cast(periodic, tf.float32)
    size_fp = periodic_size_fp * is_periodic_fp + symmetric_size_fp * (1.0 - is_periodic_fp)

    two_pi = tf.constant(6.28319, dtype=tf.float32)
    angular_increment = tf.math.divide_no_nan(two_pi, size_fp)
    range_vals = tf.range(tf.cast(periodic_size_fp, tf.int32), dtype=tf.float32)
    range_angular = range_vals * angular_increment

    a0 = tf.constant(0.42, dtype=tf.float32)
    a1 = tf.constant(0.5, dtype=tf.float32)
    a2 = tf.constant(0.08, dtype=tf.float32)

    temp0 = a0 - a1 * tf.cos(range_angular)
    temp1 = temp0 + a2 * tf.cos(range_angular * 2.0)
    tf_layers_dict[graph_node_output.name]['tf_node'] = tf.cast(
        temp1,
        dtype=output_datatype,
    )

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'BlackmanWindow',
                'tf_inputs': {
                    'size': size,
                    'periodic': periodic,
                    'dtype': output_datatype,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
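
The window generated above follows the ONNX BlackmanWindow definition w[n] = 0.42 - 0.5*cos(2*pi*n/N) + 0.08*cos(4*pi*n/N), with N = size for periodic windows and N = size - 1 for symmetric ones; note that the op encodes 2*pi as the literal 6.28319 and uses tf.math.divide_no_nan to guard the size = 1 symmetric case. A NumPy reference sketch of the same formula, checked against numpy.blackman for the symmetric case (no size = 1 guard here):

import numpy as np

def blackman_window(size, periodic=True):
    # w[n] = 0.42 - 0.5*cos(2*pi*n/N) + 0.08*cos(4*pi*n/N),
    # with N = size (periodic) or N = size - 1 (symmetric).
    n = np.arange(size, dtype=np.float64)
    N = float(size) if periodic else float(size - 1)
    ang = 2.0 * np.pi * n / N
    return 0.42 - 0.5 * np.cos(ang) + 0.08 * np.cos(2.0 * ang)

# The symmetric variant matches NumPy's reference implementation.
assert np.allclose(blackman_window(16, periodic=False), np.blackman(16))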
{onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx2tf
-Version: 1.29.5
+Version: 1.29.6
 Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC).
 Home-page: https://github.com/PINTO0309/onnx2tf
 Author: Katsuya Hyodo
@@ -95,7 +95,7 @@ https://github.com/PINTO0309/onnx2tf/wiki/model_status
 |Acosh|:heavy_check_mark:|
 |Acos|:heavy_check_mark:|
 |Add|:heavy_check_mark:|
-|AffineGrid
+|AffineGrid|:heavy_check_mark:|
 |And|:heavy_check_mark:|
 |ArgMax|:heavy_check_mark:|
 |ArgMin|:heavy_check_mark:|
@@ -103,7 +103,7 @@ https://github.com/PINTO0309/onnx2tf/wiki/model_status
 |Asin|:heavy_check_mark:|
 |Atanh|:heavy_check_mark:|
 |Atan|:heavy_check_mark:|
-|Attention
+|Attention|:heavy_check_mark:|
 |AveragePool|:heavy_check_mark:|
 |BatchNormalization|:heavy_check_mark:|
 |Bernoulli|:heavy_check_mark:|
@@ -112,7 +112,7 @@ https://github.com/PINTO0309/onnx2tf/wiki/model_status
 |BitwiseNot|:heavy_check_mark:|
 |BitwiseOr|:heavy_check_mark:|
 |BitwiseXor|:heavy_check_mark:|
-|BlackmanWindow
+|BlackmanWindow|:heavy_check_mark:|
 |Cast|:heavy_check_mark:|
 |Ceil|:heavy_check_mark:|
 |Celu|:heavy_check_mark:|
{onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/RECORD
CHANGED
@@ -1,10 +1,11 @@
-onnx2tf/__init__.py,sha256=
+onnx2tf/__init__.py,sha256=CNyoySoQEQCjv_qCPU4gLxS5PyakyrIhOAy3AYHjpsI,66
 onnx2tf/__main__.py,sha256=2RSCQ7d4lc6CwD-rlGn9UicPFg-P5du7ZD_yh-kuBEU,57
 onnx2tf/onnx2tf.py,sha256=wdBA-lgCEu-ZfUAKIUQgLe8hSP8ifE7rS6nWAq6iF6o,151519
 onnx2tf/ops/Abs.py,sha256=V7btmCG_ZvK_qJovUsguq0ZMJ349mhNQ4FHSgzP_Yuo,4029
 onnx2tf/ops/Acos.py,sha256=Fo8YkFKuWq8Fi2xUrBdKcAH1yJ8r5pjSD0wgLttTNdk,4003
 onnx2tf/ops/Acosh.py,sha256=ATQj2cT5JS_mTfXi0kXqJ1yzSZu5J0zHA5VjV3j7uKY,3588
 onnx2tf/ops/Add.py,sha256=pgJTnV1wZZk3mRaVxxezVkArfmlqlk74DCMZDm6VRJc,12295
+onnx2tf/ops/AffineGrid.py,sha256=j_Q0gRoWpQhep7xHQVqyEBiCbe4yiNelIYSsvq0MPXg,6281
 onnx2tf/ops/And.py,sha256=_ubtWa0r8-60x__pS7MEMil1DfBqxiUsk66yRCYS4KY,4591
 onnx2tf/ops/ArgMax.py,sha256=F3PV4EchYQgH1GATJybVGnmY9sGvZkgxCHbNCue9Qns,7278
 onnx2tf/ops/ArgMin.py,sha256=32r7I8AYLQOKTPOOPX1AZwiPnQfkrFB0Le16vdJ1yBs,4225
@@ -12,6 +13,7 @@ onnx2tf/ops/Asin.py,sha256=2djUjTaOzXM6t4Qb-EEMZY-pm1rJl24cgcrep2i_6aQ,4003
 onnx2tf/ops/Asinh.py,sha256=74ZzTEkpxZY4CGfJT2JJU-SHXYL5KZeUkWY2v7hsMMw,3588
 onnx2tf/ops/Atan.py,sha256=D24XDMxEwXFtJheQAr3V3IWOUOc6Q5M0-b_83bmGGMM,3981
 onnx2tf/ops/Atanh.py,sha256=VsUYopBWWPoo4gta1_aqvUL6NrVXuVkGid4SqDqYJ9Q,3588
+onnx2tf/ops/Attention.py,sha256=7TMOdPztVLtNKSzeozvaRxhUFVhACci8wvhn7ONKWrQ,21006
 onnx2tf/ops/AveragePool.py,sha256=kifQJZplqC2Px209BotbjXCPpRBQQsB8DlJYJTvJD78,20065
 onnx2tf/ops/BatchNormalization.py,sha256=_hlf2-5-j3MCJHEoE2oMNQ8YhCm7ad9h2fwPpTo3i7g,26624
 onnx2tf/ops/Bernoulli.py,sha256=PM0xS0n1q4bnT_9PnbcKW8_Qj8dJYYBQR8kb2X-wIp4,3670
@@ -20,6 +22,7 @@ onnx2tf/ops/BitwiseAnd.py,sha256=snmmVzVwLxhWh0aKyaskScBvefncGyW7ZPVrmbugazk,345
 onnx2tf/ops/BitwiseNot.py,sha256=QuFUyK24JGrEOKYu-6lRi9uZLz4MKVtBwUqzDdqtBKA,2721
 onnx2tf/ops/BitwiseOr.py,sha256=WSswhA3qmp3OJ4iIibl_2ps-tZEyfKI7B19GiFH7Uik,3453
 onnx2tf/ops/BitwiseXor.py,sha256=d1WoshWdfcoQnYrdaxafRleipy1d0AKleTgh0G7lZlw,3456
+onnx2tf/ops/BlackmanWindow.py,sha256=o_wLhYAmMearuJNlSdUfDeQm7D6g_y_H21uG-foctbA,3532
 onnx2tf/ops/Cast.py,sha256=M0LRClHPgZ_8NubwME6ipKrAqcY9aKC5ihQXCkTkNkM,4601
 onnx2tf/ops/Ceil.py,sha256=0-jaueltpQSwpOIDUmy9DdTy98qN-XimYu5cHVPnUIs,3586
 onnx2tf/ops/Celu.py,sha256=9g7WNKo4G_jMtUXcoOfpNdLYqEsuyXLPkkyQZxDuL4U,3853
@@ -195,9 +198,9 @@ onnx2tf/utils/enums.py,sha256=7c5TqetqB07VjyHoxJHfLgtqBqk9ZRyUF33fPOJR1IM,1649
 onnx2tf/utils/iterative_json_optimizer.py,sha256=qqeIxWGxrhcCYk8-ebWnblnOkzDCwi-nseipHzHR_bk,10436
 onnx2tf/utils/json_auto_generator.py,sha256=OC-SfKtUg7zUxaXTAg6kT0ShzIc3ByjDa3FNp173DtA,60302
 onnx2tf/utils/logging.py,sha256=yUCmPuJ_XiUItM3sZMcaMO24JErkQy7zZwVTYWAuiKg,1982
-onnx2tf-1.29.5.dist-info/licenses/LICENSE,sha256=
-onnx2tf-1.29.5.dist-info/licenses/LICENSE_onnx-tensorflow,sha256=
-onnx2tf-1.29.5.dist-info/METADATA,sha256=
-onnx2tf-1.29.5.dist-info/WHEEL,sha256=
-onnx2tf-1.29.5.dist-info/top_level.txt,sha256=
-onnx2tf-1.29.5.dist-info/RECORD,,
+onnx2tf-1.29.6.dist-info/licenses/LICENSE,sha256=5v_Kxihy8i6mzHVl349ikSREaIdsl9YeUnX1KBDLD2w,1070
+onnx2tf-1.29.6.dist-info/licenses/LICENSE_onnx-tensorflow,sha256=gK4GtS9S5YcyINu6uuNNWdo-kBClyEM4MFLFGiNTeRM,11231
+onnx2tf-1.29.6.dist-info/METADATA,sha256=uth9COxuJ3aYEd7NDot-Nn2pvwnoEFPf66plCp-s1tk,153697
+onnx2tf-1.29.6.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+onnx2tf-1.29.6.dist-info/top_level.txt,sha256=WgfPiEy3f6vZ_FOpAIEA2CF3TCx1eYrhGw93Ih6b9Fw,8
+onnx2tf-1.29.6.dist-info/RECORD,,
{onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/licenses/LICENSE
File without changes

{onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/licenses/LICENSE_onnx-tensorflow
File without changes

{onnx2tf-1.29.5.dist-info → onnx2tf-1.29.6.dist-info}/top_level.txt
File without changes