onnx2tf 1.29.17__py3-none-any.whl → 1.29.19__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
@@ -0,0 +1,285 @@
+ import sys
+ import random
+ random.seed(0)
+ import numpy as np
+ np.random.seed(0)
+ import tensorflow as tf
+ import onnx_graphsurgeon as gs
+ from onnx2tf.utils.common_functions import (
+     get_constant_or_variable,
+     print_node_info,
+     inverted_operation_enable_disable,
+     make_tf_node_info,
+     get_replacement_parameter,
+     pre_process_transpose,
+     post_process_transpose,
+ )
+ from onnx2tf.utils.logging import *
+
+
+ def _as_tensor(value):
+     if isinstance(value, np.ndarray):
+         return tf.convert_to_tensor(value)
+     if isinstance(value, (np.generic, int, float, bool, str, bytes)):
+         return tf.convert_to_tensor(value)
+     return value
+
+
+ def _split_rotary(input_tensor, rotary_dim):
+     if isinstance(rotary_dim, int):
+         x_rotate = input_tensor[:, :, :, :rotary_dim]
+         x_not_rotate = input_tensor[:, :, :, rotary_dim:]
+         return x_rotate, x_not_rotate
+     rotary_dim = tf.cast(rotary_dim, tf.int32)
+     input_shape = tf.shape(input_tensor)
+     head_size = input_shape[-1]
+     x_rotate = tf.slice(
+         input_tensor,
+         [0, 0, 0, 0],
+         [-1, -1, -1, rotary_dim],
+     )
+     x_not_rotate = tf.slice(
+         input_tensor,
+         [0, 0, 0, rotary_dim],
+         [-1, -1, -1, head_size - rotary_dim],
+     )
+     return x_rotate, x_not_rotate
+
+
+ @print_node_info
+ @inverted_operation_enable_disable
+ @get_replacement_parameter
+ def make_node(
+     *,
+     graph_node: gs.Node,
+     tf_layers_dict: dict,
+     **kwargs: dict,
+ ):
+     """RotaryEmbedding
+
+     Parameters
+     ----------
+     graph_node: gs.Node
+         graph_surgeon Node
+
+     tf_layers_dict: dict
+         optype, shape, dtype, tensorflow graph
+     """
+     before_op_output_shape_trans_1 = \
+         tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+     before_op_output_shape_trans_2 = \
+         tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
+     before_op_output_shape_trans_3 = \
+         tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
+     before_op_output_shape_trans = \
+         before_op_output_shape_trans_1 \
+         and before_op_output_shape_trans_2 \
+         and before_op_output_shape_trans_3
+     if len(graph_node.inputs) >= 4:
+         before_op_output_shape_trans_4 = \
+             tf_layers_dict.get(graph_node.inputs[3].name, {}).get('before_op_output_shape_trans', True)
+         before_op_output_shape_trans = \
+             before_op_output_shape_trans \
+             and before_op_output_shape_trans_4
+
+     graph_node_input_1 = get_constant_or_variable(
+         graph_node.inputs[0],
+         before_op_output_shape_trans,
+     )
+     graph_node_input_2 = get_constant_or_variable(
+         graph_node.inputs[1],
+         before_op_output_shape_trans,
+     )
+     graph_node_input_3 = get_constant_or_variable(
+         graph_node.inputs[2],
+         before_op_output_shape_trans,
+     )
+     graph_node_input_4 = None
+     if len(graph_node.inputs) >= 4:
+         graph_node_input_4 = get_constant_or_variable(
+             graph_node.inputs[3],
+             before_op_output_shape_trans=False,
+         )
+
+     graph_node_output: gs.Variable = graph_node.outputs[0]
+     shape = graph_node_output.shape
+     dtype = graph_node_output.dtype
+
+     input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
+         if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+     cos_cache = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
+         if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
+     sin_cache = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
+         if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3
+     position_ids = None
+     if graph_node_input_4 is not None:
+         position_ids = tf_layers_dict[graph_node_input_4.name]['tf_node'] \
+             if isinstance(graph_node_input_4, gs.Variable) else graph_node_input_4
+
+     # Preserving Graph Structure (Dict)
+     tf_layers_dict[graph_node_output.name] = {
+         'optype': graph_node.op,
+         'shape': shape,
+         'dtype': dtype,
+     }
+
+     # Pre-process transpose
+     input_tensor = pre_process_transpose(
+         value_before_transpose=input_tensor,
+         param_target='inputs',
+         param_name=graph_node.inputs[0].name,
+         **kwargs,
+     )
+     cos_cache = pre_process_transpose(
+         value_before_transpose=cos_cache,
+         param_target='inputs',
+         param_name=graph_node.inputs[1].name,
+         **kwargs,
+     )
+     sin_cache = pre_process_transpose(
+         value_before_transpose=sin_cache,
+         param_target='inputs',
+         param_name=graph_node.inputs[2].name,
+         **kwargs,
+     )
+     if position_ids is not None:
+         position_ids = pre_process_transpose(
+             value_before_transpose=position_ids,
+             param_target='inputs',
+             param_name=graph_node.inputs[3].name,
+             **kwargs,
+         )
+
+     # Generation of TF OP
+     input_tensor = _as_tensor(input_tensor)
+     cos_cache = _as_tensor(cos_cache)
+     sin_cache = _as_tensor(sin_cache)
+     if position_ids is not None:
+         position_ids = _as_tensor(position_ids)
+
+     input_dtype = input_tensor.dtype
+     if cos_cache.dtype != input_dtype:
+         cos_cache = tf.cast(cos_cache, input_dtype)
+     if sin_cache.dtype != input_dtype:
+         sin_cache = tf.cast(sin_cache, input_dtype)
+
+     input_rank = input_tensor.shape.rank
+     if input_rank is None:
+         error(
+             f'RotaryEmbedding only supports 3D/4D input with known rank.\n' +
+             f'graph_node.name: {graph_node.name}'
+         )
+         sys.exit(1)
+
+     original_input_tensor = input_tensor
+     original_input_shape = tf.shape(original_input_tensor)
+
+     num_heads = graph_node.attrs.get('num_heads', None)
+     rotary_embedding_dim = graph_node.attrs.get('rotary_embedding_dim', 0)
+     interleaved = bool(graph_node.attrs.get('interleaved', 0))
+
+     if input_rank == 4:
+         input_tensor = tf.transpose(input_tensor, perm=[0, 2, 1, 3])
+     elif input_rank == 3:
+         if num_heads is None or int(num_heads) == 0:
+             error(
+                 f'num_heads attribute is required for 3D input in RotaryEmbedding.\n' +
+                 f'graph_node.name: {graph_node.name}'
+             )
+             sys.exit(1)
+         num_heads = int(num_heads)
+         input_shape = tf.shape(input_tensor)
+         head_size = tf.math.floordiv(
+             input_shape[-1],
+             tf.constant(num_heads, dtype=input_shape.dtype),
+         )
+         input_tensor = tf.reshape(
+             input_tensor,
+             tf.stack(
+                 [
+                     input_shape[0],
+                     input_shape[1],
+                     tf.constant(num_heads, dtype=input_shape.dtype),
+                     head_size,
+                 ]
+             ),
+         )
+     else:
+         error(
+             f'RotaryEmbedding only supports 3D/4D input.\n' +
+             f'graph_node.name: {graph_node.name}'
+         )
+         sys.exit(1)
+
+     head_size = input_tensor.shape[-1]
+     if head_size is None:
+         head_size = tf.shape(input_tensor)[-1]
+
+     if rotary_embedding_dim is None or int(rotary_embedding_dim) == 0:
+         rotary_embedding_dim = head_size
+     else:
+         rotary_embedding_dim = int(rotary_embedding_dim)
+
+     x_rotate, x_not_rotate = _split_rotary(input_tensor, rotary_embedding_dim)
+
+     if position_ids is not None:
+         cos_cache = tf.gather(cos_cache, position_ids)
+         sin_cache = tf.gather(sin_cache, position_ids)
+
+     cos_cache = tf.expand_dims(cos_cache, axis=2)
+     sin_cache = tf.expand_dims(sin_cache, axis=2)
+
+     if interleaved:
+         x1 = x_rotate[:, :, :, 0::2]
+         x2 = x_rotate[:, :, :, 1::2]
+     else:
+         x1, x2 = tf.split(x_rotate, num_or_size_splits=2, axis=-1)
+
+     real = (cos_cache * x1) - (sin_cache * x2)
+     imag = (sin_cache * x1) + (cos_cache * x2)
+
+     if interleaved:
+         real = tf.expand_dims(real, axis=-1)
+         imag = tf.expand_dims(imag, axis=-1)
+         x_rotate = tf.reshape(
+             tf.concat([real, imag], axis=-1),
+             tf.shape(x_rotate),
+         )
+     else:
+         x_rotate = tf.concat([real, imag], axis=-1)
+
+     output_tensor = tf.concat([x_rotate, x_not_rotate], axis=-1)
+     if input_rank == 3:
+         output_tensor = tf.reshape(output_tensor, original_input_shape)
+     else:
+         output_tensor = tf.transpose(output_tensor, perm=[0, 2, 1, 3])
+
+     tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor
+
+     # Post-process transpose
+     tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
+         value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
+         param_target='outputs',
+         param_name=graph_node.outputs[0].name,
+         **kwargs,
+     )
+
+     # Generation of Debug Info
+     tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+         make_tf_node_info(
+             node_info={
+                 'tf_op_type': 'RotaryEmbedding',
+                 'tf_inputs': {
+                     'input': input_tensor,
+                     'cos_cache': cos_cache,
+                     'sin_cache': sin_cache,
+                     'position_ids': position_ids,
+                     'interleaved': interleaved,
+                     'rotary_embedding_dim': rotary_embedding_dim,
+                     'num_heads': num_heads,
+                 },
+                 'tf_outputs': {
+                     'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                 },
+             }
+         )
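For reference, the handler added above implements rotary position embedding (RoPE): the first rotary_embedding_dim channels of each head are treated as (x1, x2) pairs and rotated by position-dependent angles taken from cos_cache/sin_cache, while the remaining channels pass through unchanged. Below is a minimal standalone sketch of the non-interleaved (half-split) path using a synthetic cos/sin cache; the shapes and the cache construction are illustrative assumptions, not values from the package.

# Illustrative sketch only -- not part of the released file.
# Assumes the [batch, seq, heads, head_size] layout the handler uses after its
# 4D transpose, and a standard RoPE cache built from inverse frequencies.
import numpy as np
import tensorflow as tf

batch, seq, heads, head_size = 1, 4, 2, 8
rotary_dim = head_size  # full-head rotation, as when rotary_embedding_dim == 0

x = tf.random.normal([batch, seq, heads, head_size])

# angles[t, j] = t / 10000**(2j / rotary_dim), one angle per rotated pair
inv_freq = 1.0 / (10000.0 ** (np.arange(0, rotary_dim, 2) / rotary_dim))
angles = np.arange(seq)[:, None] * inv_freq[None, :]        # [seq, rotary_dim // 2]
cos = tf.constant(np.cos(angles), dtype=tf.float32)[None, :, None, :]
sin = tf.constant(np.sin(angles), dtype=tf.float32)[None, :, None, :]

# Half-split pairing, mirroring the interleaved=False branch of the handler
x1, x2 = tf.split(x, num_or_size_splits=2, axis=-1)
real = (cos * x1) - (sin * x2)
imag = (sin * x1) + (cos * x2)
rotated = tf.concat([real, imag], axis=-1)
print(rotated.shape)  # (1, 4, 2, 8) -- same shape as the input

With interleaved=1 the pairs would instead be the even/odd channels (x[..., 0::2] and x[..., 1::2]), and the rotated halves would be re-interleaved, which is what the expand_dims/concat/reshape sequence in the handler accomplishes.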