onnx2tf 1.29.17-py3-none-any.whl → 1.29.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/ops/Col2Im.py +108 -64
- onnx2tf/ops/DFT.py +245 -0
- onnx2tf/ops/DeformConv.py +399 -0
- onnx2tf/ops/ImageDecoder.py +147 -0
- onnx2tf/ops/NegativeLogLikelihoodLoss.py +237 -0
- onnx2tf/ops/RMSNormalization.py +175 -0
- onnx2tf/ops/RegexFullMatch.py +108 -0
- onnx2tf/ops/RotaryEmbedding.py +285 -0
- onnx2tf/ops/Scan.py +438 -0
- onnx2tf/ops/SoftmaxCrossEntropyLoss.py +289 -0
- onnx2tf/ops/StringConcat.py +128 -0
- onnx2tf/ops/StringNormalizer.py +54 -39
- onnx2tf/ops/StringSplit.py +156 -0
- onnx2tf/ops/TensorScatter.py +223 -0
- {onnx2tf-1.29.17.dist-info → onnx2tf-1.29.19.dist-info}/METADATA +15 -14
- {onnx2tf-1.29.17.dist-info → onnx2tf-1.29.19.dist-info}/RECORD +19 -7
- {onnx2tf-1.29.17.dist-info → onnx2tf-1.29.19.dist-info}/WHEEL +1 -1
- {onnx2tf-1.29.17.dist-info → onnx2tf-1.29.19.dist-info}/entry_points.txt +0 -0
onnx2tf/ops/RotaryEmbedding.py
@@ -0,0 +1,285 @@
+import sys
+import random
+random.seed(0)
+import numpy as np
+np.random.seed(0)
+import tensorflow as tf
+import onnx_graphsurgeon as gs
+from onnx2tf.utils.common_functions import (
+    get_constant_or_variable,
+    print_node_info,
+    inverted_operation_enable_disable,
+    make_tf_node_info,
+    get_replacement_parameter,
+    pre_process_transpose,
+    post_process_transpose,
+)
+from onnx2tf.utils.logging import *
+
+
+def _as_tensor(value):
+    if isinstance(value, np.ndarray):
+        return tf.convert_to_tensor(value)
+    if isinstance(value, (np.generic, int, float, bool, str, bytes)):
+        return tf.convert_to_tensor(value)
+    return value
+
+
+def _split_rotary(input_tensor, rotary_dim):
+    if isinstance(rotary_dim, int):
+        x_rotate = input_tensor[:, :, :, :rotary_dim]
+        x_not_rotate = input_tensor[:, :, :, rotary_dim:]
+        return x_rotate, x_not_rotate
+    rotary_dim = tf.cast(rotary_dim, tf.int32)
+    input_shape = tf.shape(input_tensor)
+    head_size = input_shape[-1]
+    x_rotate = tf.slice(
+        input_tensor,
+        [0, 0, 0, 0],
+        [-1, -1, -1, rotary_dim],
+    )
+    x_not_rotate = tf.slice(
+        input_tensor,
+        [0, 0, 0, rotary_dim],
+        [-1, -1, -1, head_size - rotary_dim],
+    )
+    return x_rotate, x_not_rotate
+
+
+@print_node_info
+@inverted_operation_enable_disable
+@get_replacement_parameter
+def make_node(
+    *,
+    graph_node: gs.Node,
+    tf_layers_dict: dict,
+    **kwargs: dict,
+):
+    """RotaryEmbedding
+
+    Parameters
+    ----------
+    graph_node: gs.Node
+        graph_surgeon Node
+
+    tf_layers_dict: dict
+        optype, shape, dtype, tensorflow graph
+    """
+    before_op_output_shape_trans_1 = \
+        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_2 = \
+        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans_3 = \
+        tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans = \
+        before_op_output_shape_trans_1 \
+        and before_op_output_shape_trans_2 \
+        and before_op_output_shape_trans_3
+    if len(graph_node.inputs) >= 4:
+        before_op_output_shape_trans_4 = \
+            tf_layers_dict.get(graph_node.inputs[3].name, {}).get('before_op_output_shape_trans', True)
+        before_op_output_shape_trans = \
+            before_op_output_shape_trans \
+            and before_op_output_shape_trans_4
+
+    graph_node_input_1 = get_constant_or_variable(
+        graph_node.inputs[0],
+        before_op_output_shape_trans,
+    )
+    graph_node_input_2 = get_constant_or_variable(
+        graph_node.inputs[1],
+        before_op_output_shape_trans,
+    )
+    graph_node_input_3 = get_constant_or_variable(
+        graph_node.inputs[2],
+        before_op_output_shape_trans,
+    )
+    graph_node_input_4 = None
+    if len(graph_node.inputs) >= 4:
+        graph_node_input_4 = get_constant_or_variable(
+            graph_node.inputs[3],
+            before_op_output_shape_trans=False,
+        )
+
+    graph_node_output: gs.Variable = graph_node.outputs[0]
+    shape = graph_node_output.shape
+    dtype = graph_node_output.dtype
+
+    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
+        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+    cos_cache = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
+        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
+    sin_cache = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
+        if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
+    position_ids = None
+    if graph_node_input_4 is not None:
+        position_ids = tf_layers_dict[graph_node_input_4.name]['tf_node'] \
+            if isinstance(graph_node_input_4, gs.Variable) else graph_node_input_4
+
+    # Preserving Graph Structure (Dict)
+    tf_layers_dict[graph_node_output.name] = {
+        'optype': graph_node.op,
+        'shape': shape,
+        'dtype': dtype,
+    }
+
+    # Pre-process transpose
+    input_tensor = pre_process_transpose(
+        value_before_transpose=input_tensor,
+        param_target='inputs',
+        param_name=graph_node.inputs[0].name,
+        **kwargs,
+    )
+    cos_cache = pre_process_transpose(
+        value_before_transpose=cos_cache,
+        param_target='inputs',
+        param_name=graph_node.inputs[1].name,
+        **kwargs,
+    )
+    sin_cache = pre_process_transpose(
+        value_before_transpose=sin_cache,
+        param_target='inputs',
+        param_name=graph_node.inputs[2].name,
+        **kwargs,
+    )
+    if position_ids is not None:
+        position_ids = pre_process_transpose(
+            value_before_transpose=position_ids,
+            param_target='inputs',
+            param_name=graph_node.inputs[3].name,
+            **kwargs,
+        )
+
+    # Generation of TF OP
+    input_tensor = _as_tensor(input_tensor)
+    cos_cache = _as_tensor(cos_cache)
+    sin_cache = _as_tensor(sin_cache)
+    if position_ids is not None:
+        position_ids = _as_tensor(position_ids)
+
+    input_dtype = input_tensor.dtype
+    if cos_cache.dtype != input_dtype:
+        cos_cache = tf.cast(cos_cache, input_dtype)
+    if sin_cache.dtype != input_dtype:
+        sin_cache = tf.cast(sin_cache, input_dtype)
+
+    input_rank = input_tensor.shape.rank
+    if input_rank is None:
+        error(
+            f'RotaryEmbedding only supports 3D/4D input with known rank.\n' +
+            f'graph_node.name: {graph_node.name}'
+        )
+        sys.exit(1)
+
+    original_input_tensor = input_tensor
+    original_input_shape = tf.shape(original_input_tensor)
+
+    num_heads = graph_node.attrs.get('num_heads', None)
+    rotary_embedding_dim = graph_node.attrs.get('rotary_embedding_dim', 0)
+    interleaved = bool(graph_node.attrs.get('interleaved', 0))
+
+    if input_rank == 4:
+        input_tensor = tf.transpose(input_tensor, perm=[0, 2, 1, 3])
+    elif input_rank == 3:
+        if num_heads is None or int(num_heads) == 0:
+            error(
+                f'num_heads attribute is required for 3D input in RotaryEmbedding.\n' +
+                f'graph_node.name: {graph_node.name}'
+            )
+            sys.exit(1)
+        num_heads = int(num_heads)
+        input_shape = tf.shape(input_tensor)
+        head_size = tf.math.floordiv(
+            input_shape[-1],
+            tf.constant(num_heads, dtype=input_shape.dtype),
+        )
+        input_tensor = tf.reshape(
+            input_tensor,
+            tf.stack(
+                [
+                    input_shape[0],
+                    input_shape[1],
+                    tf.constant(num_heads, dtype=input_shape.dtype),
+                    head_size,
+                ]
+            ),
+        )
+    else:
+        error(
+            f'RotaryEmbedding only supports 3D/4D input.\n' +
+            f'graph_node.name: {graph_node.name}'
+        )
+        sys.exit(1)
+
+    head_size = input_tensor.shape[-1]
+    if head_size is None:
+        head_size = tf.shape(input_tensor)[-1]
+
+    if rotary_embedding_dim is None or int(rotary_embedding_dim) == 0:
+        rotary_embedding_dim = head_size
+    else:
+        rotary_embedding_dim = int(rotary_embedding_dim)
+
+    x_rotate, x_not_rotate = _split_rotary(input_tensor, rotary_embedding_dim)
+
+    if position_ids is not None:
+        cos_cache = tf.gather(cos_cache, position_ids)
+        sin_cache = tf.gather(sin_cache, position_ids)
+
+    cos_cache = tf.expand_dims(cos_cache, axis=2)
+    sin_cache = tf.expand_dims(sin_cache, axis=2)
+
+    if interleaved:
+        x1 = x_rotate[:, :, :, 0::2]
+        x2 = x_rotate[:, :, :, 1::2]
+    else:
+        x1, x2 = tf.split(x_rotate, num_or_size_splits=2, axis=-1)
+
+    real = (cos_cache * x1) - (sin_cache * x2)
+    imag = (sin_cache * x1) + (cos_cache * x2)
+
+    if interleaved:
+        real = tf.expand_dims(real, axis=-1)
+        imag = tf.expand_dims(imag, axis=-1)
+        x_rotate = tf.reshape(
+            tf.concat([real, imag], axis=-1),
+            tf.shape(x_rotate),
+        )
+    else:
+        x_rotate = tf.concat([real, imag], axis=-1)
+
+    output_tensor = tf.concat([x_rotate, x_not_rotate], axis=-1)
+    if input_rank == 3:
+        output_tensor = tf.reshape(output_tensor, original_input_shape)
+    else:
+        output_tensor = tf.transpose(output_tensor, perm=[0, 2, 1, 3])
+
+    tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor
+
+    # Post-process transpose
+    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
+        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
+        param_target='outputs',
+        param_name=graph_node.outputs[0].name,
+        **kwargs,
+    )
+
+    # Generation of Debug Info
+    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+        make_tf_node_info(
+            node_info={
+                'tf_op_type': 'RotaryEmbedding',
+                'tf_inputs': {
+                    'input': input_tensor,
+                    'cos_cache': cos_cache,
+                    'sin_cache': sin_cache,
+                    'position_ids': position_ids,
+                    'interleaved': interleaved,
+                    'rotary_embedding_dim': rotary_embedding_dim,
+                    'num_heads': num_heads,
+                },
+                'tf_outputs': {
+                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                },
+            }
+        )