onnx2tf 1.29.16__py3-none-any.whl → 1.29.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,289 @@
1
+ import sys
2
+ import random
3
+ random.seed(0)
4
+ import numpy as np
5
+ np.random.seed(0)
6
+ import tensorflow as tf
7
+ import onnx_graphsurgeon as gs
8
+ from onnx2tf.utils.common_functions import (
9
+ get_constant_or_variable,
10
+ convert_axis,
11
+ print_node_info,
12
+ inverted_operation_enable_disable,
13
+ make_tf_node_info,
14
+ get_replacement_parameter,
15
+ pre_process_transpose,
16
+ post_process_transpose,
17
+ )
18
+ from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
19
+ from onnx2tf.utils.logging import *
20
+
21
+
22
+ def _as_tensor(value):
23
+ if isinstance(value, np.ndarray):
24
+ return tf.convert_to_tensor(value)
25
+ if isinstance(value, (np.generic, int, float, bool, str, bytes)):
26
+ return tf.convert_to_tensor(value)
27
+ return value
28
+
29
+
30
+ def _move_class_to_last(tensor, class_axis):
31
+ rank = tensor.shape.rank
32
+ if rank is None:
33
+ rank = tf.rank(tensor)
34
+ if isinstance(rank, int):
35
+ if class_axis == rank - 1:
36
+ return tensor, None
37
+ perm = [i for i in range(rank) if i != class_axis] + [class_axis]
38
+ return tf.transpose(tensor, perm=perm), perm
39
+ return tensor, None
40
+
41
+
42
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """SoftmaxCrossEntropyLoss

    Converts an ONNX SoftmaxCrossEntropyLoss node. Inputs are ``scores``,
    integer ``labels`` and an optional per-class ``weights`` tensor. The
    ``reduction`` attribute ('none' | 'sum' | 'mean', default 'mean') and the
    optional ``ignore_index`` attribute are honoured. Output 0 is the loss;
    the optional output 1 is ``log_softmax(scores)`` over the class axis.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2
    if len(graph_node.inputs) >= 3:
        before_op_output_shape_trans_3 = \
            tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
        before_op_output_shape_trans = \
            before_op_output_shape_trans \
            and before_op_output_shape_trans_3

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_input_3 = None
    if len(graph_node.inputs) >= 3:
        graph_node_input_3 = get_constant_or_variable(
            graph_node.inputs[2],
            before_op_output_shape_trans,
        )

    graph_node_output_1: gs.Variable = graph_node.outputs[0]
    output_1_shape = graph_node_output_1.shape
    output_1_dtype = graph_node_output_1.dtype
    output_1_tf_dtype = NUMPY_DTYPES_TO_TF_DTYPES[output_1_dtype] \
        if isinstance(output_1_dtype, np.dtype) else output_1_dtype

    graph_node_output_2 = None
    if len(graph_node.outputs) >= 2:
        graph_node_output_2 = graph_node.outputs[1]
        output_2_shape = graph_node_output_2.shape
        output_2_dtype = graph_node_output_2.dtype
        output_2_tf_dtype = NUMPY_DTYPES_TO_TF_DTYPES[output_2_dtype] \
            if isinstance(output_2_dtype, np.dtype) else output_2_dtype
    else:
        output_2_tf_dtype = None

    scores_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    labels_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
    weight_tensor = None
    if graph_node_input_3 is not None:
        weight_tensor = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
            if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3

    reduction = graph_node.attrs.get('reduction', 'mean')
    ignore_index = graph_node.attrs.get('ignore_index', None)

    input_rank = len(scores_tensor.shape)
    class_axis = convert_axis(
        axis=1,
        tensor_rank=input_rank,
        before_op_output_shape_trans=before_op_output_shape_trans,
    )

    # 'nhwc' flag inherited from the scores input (False for constants or
    # when the producer did not record one). Shared by both outputs.
    input_1_nhwc = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
        if isinstance(graph_node_input_1, gs.Variable) \
        and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False

    # Preserving Graph Structure (Dict)
    output_entry = {
        'optype': graph_node.op,
        'shape': output_1_shape,
        'dtype': output_1_dtype,
    }
    if reduction == 'none':
        # 'sum'/'mean' reduce to a scalar, so the flag is only recorded for
        # the unreduced loss.
        output_entry['nhwc'] = input_1_nhwc
    tf_layers_dict[graph_node_output_1.name] = output_entry
    if graph_node_output_2 is not None:
        tf_layers_dict[graph_node_output_2.name] = {
            'optype': graph_node.op,
            'shape': output_2_shape,
            'dtype': output_2_dtype,
            'nhwc': input_1_nhwc,
        }

    # Pre-process transpose
    scores_tensor = pre_process_transpose(
        value_before_transpose=scores_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )
    labels_tensor = pre_process_transpose(
        value_before_transpose=labels_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[1].name,
        **kwargs,
    )
    if weight_tensor is not None:
        weight_tensor = pre_process_transpose(
            value_before_transpose=weight_tensor,
            param_target='inputs',
            param_name=graph_node.inputs[2].name,
            **kwargs,
        )

    # Generation of TF OP
    scores_tensor = _as_tensor(scores_tensor)
    labels_tensor = _as_tensor(labels_tensor)
    if weight_tensor is not None:
        weight_tensor = _as_tensor(weight_tensor)

    log_prob = tf.nn.log_softmax(
        logits=scores_tensor,
        axis=class_axis,
    )

    # Put the class axis last so labels can be one-hot encoded along axis -1.
    log_prob_for_loss, _ = _move_class_to_last(log_prob, class_axis)

    depth = log_prob_for_loss.shape[-1]
    if depth is None:
        depth = tf.shape(log_prob_for_loss)[-1]
    depth = tf.cast(depth, tf.int32)

    labels = tf.cast(labels_tensor, tf.int32)
    if ignore_index is not None:
        ignore_index_val = tf.cast(ignore_index, labels.dtype)
        mask = tf.equal(labels, ignore_index_val)
        # Swap ignored labels for class 0 so one_hot/gather stay in range;
        # their contribution is zeroed out further down.
        labels_safe = tf.where(mask, tf.zeros_like(labels), labels)
    else:
        mask = None
        labels_safe = labels

    one_hot = tf.one_hot(
        indices=labels_safe,
        depth=depth,
        axis=-1,
        dtype=log_prob_for_loss.dtype,
    )
    # Negative log-probability of the target class at every position.
    selected = tf.reduce_sum(log_prob_for_loss * one_hot, axis=-1)
    loss = -selected

    weight_per_label = None
    if weight_tensor is not None:
        weight_per_label = tf.gather(weight_tensor, labels_safe)
        weight_per_label = tf.cast(weight_per_label, loss.dtype)
        if mask is not None:
            weight_per_label = tf.where(mask, tf.zeros_like(weight_per_label), weight_per_label)
        loss = loss * weight_per_label

    if mask is not None:
        loss = tf.where(mask, tf.zeros_like(loss), loss)

    if reduction == 'none':
        output_tensor = loss
    elif reduction == 'sum':
        output_tensor = tf.reduce_sum(loss)
    elif reduction == 'mean':
        if weight_per_label is not None:
            # Weighted mean: sum(loss) / sum(weights of non-ignored labels).
            denom = tf.reduce_sum(weight_per_label)
            output_tensor = tf.math.divide_no_nan(tf.reduce_sum(loss), denom)
        elif mask is not None:
            # Without weights ONNX assigns weight 0 to ignored positions and
            # 1 elsewhere, so the mean is over non-ignored positions only
            # (a plain reduce_mean would also count the ignored zeros).
            valid_count = tf.reduce_sum(tf.cast(tf.logical_not(mask), loss.dtype))
            output_tensor = tf.math.divide_no_nan(tf.reduce_sum(loss), valid_count)
        else:
            output_tensor = tf.reduce_mean(loss)
    else:
        error(
            f'SoftmaxCrossEntropyLoss reduction={reduction} is not supported.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    if output_1_tf_dtype is not None and output_tensor.dtype != output_1_tf_dtype:
        output_tensor = tf.cast(output_tensor, output_1_tf_dtype)

    tf_layers_dict[graph_node_output_1.name]['tf_node'] = output_tensor

    if graph_node_output_2 is not None:
        log_prob_out = log_prob
        if output_2_tf_dtype is not None and log_prob_out.dtype != output_2_tf_dtype:
            log_prob_out = tf.cast(log_prob_out, output_2_tf_dtype)
        tf_layers_dict[graph_node_output_2.name]['tf_node'] = log_prob_out

    # Post-process transpose
    tf_layers_dict[graph_node_output_1.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output_1.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )
    if graph_node_output_2 is not None:
        tf_layers_dict[graph_node_output_2.name]['tf_node'] = post_process_transpose(
            value_before_transpose=tf_layers_dict[graph_node_output_2.name]['tf_node'],
            param_target='outputs',
            param_name=graph_node.outputs[1].name,
            **kwargs,
        )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output_1.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'SoftmaxCrossEntropyLoss',
                'tf_inputs': {
                    'scores': scores_tensor,
                    'labels': labels_tensor,
                    'weights': weight_tensor,
                    'reduction': reduction,
                    'ignore_index': ignore_index,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output_1.name]['tf_node'],
                    'log_prob': tf_layers_dict[graph_node_output_2.name]['tf_node'] \
                        if graph_node_output_2 is not None else None,
                },
            }
        )
    if graph_node_output_2 is not None:
        tf_layers_dict[graph_node_output_2.name]['tf_node_info'] = \
            make_tf_node_info(
                node_info={
                    'tf_op_type': tf.nn.log_softmax,
                    'tf_inputs': {
                        'logits': scores_tensor,
                        'axis': class_axis,
                    },
                    'tf_outputs': {
                        'output': tf_layers_dict[graph_node_output_2.name]['tf_node'],
                    },
                }
            )
@@ -0,0 +1,128 @@
1
+ import random
2
+ random.seed(0)
3
+ import numpy as np
4
+ np.random.seed(0)
5
+ import tensorflow as tf
6
+ import onnx_graphsurgeon as gs
7
+ from onnx2tf.utils.common_functions import (
8
+ get_constant_or_variable,
9
+ print_node_info,
10
+ inverted_operation_enable_disable,
11
+ make_tf_node_info,
12
+ get_replacement_parameter,
13
+ pre_process_transpose,
14
+ post_process_transpose,
15
+ )
16
+
17
+
18
+ def _as_tensor(value):
19
+ if isinstance(value, np.ndarray):
20
+ return tf.convert_to_tensor(value)
21
+ if isinstance(value, (np.generic, str, bytes)):
22
+ return tf.convert_to_tensor(value)
23
+ return value
24
+
25
+
26
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """StringConcat

    Concatenates two string tensors element-wise with ``tf.strings.join``.
    Both inputs are broadcast to a common shape first unless their static
    shapes are fully known and identical.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype

    input_tensor_1 = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
    }

    # Pre-process transpose
    input_tensor_1 = pre_process_transpose(
        value_before_transpose=input_tensor_1,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )
    input_tensor_2 = pre_process_transpose(
        value_before_transpose=input_tensor_2,
        param_target='inputs',
        param_name=graph_node.inputs[1].name,
        **kwargs,
    )

    # Generation of TF OP
    input_tensor_1 = _as_tensor(input_tensor_1)
    input_tensor_2 = _as_tensor(input_tensor_2)
    shape_1 = input_tensor_1.shape
    shape_2 = input_tensor_2.shape
    # A static comparison is only conclusive when both shapes are fully known:
    # e.g. [None] == [None] holds statically even if the runtime sizes are 1
    # and N. tf.strings.join expects same-shaped (non-scalar) inputs, so
    # broadcast explicitly whenever equality is not guaranteed.
    same_static_shape = \
        shape_1.is_fully_defined() \
        and shape_2.is_fully_defined() \
        and shape_1 == shape_2
    if not same_static_shape:
        out_shape = tf.broadcast_dynamic_shape(
            tf.shape(input_tensor_1),
            tf.shape(input_tensor_2),
        )
        input_tensor_1 = tf.broadcast_to(input_tensor_1, out_shape)
        input_tensor_2 = tf.broadcast_to(input_tensor_2, out_shape)

    tf_layers_dict[graph_node_output.name]['tf_node'] = \
        tf.strings.join(
            inputs=[input_tensor_1, input_tensor_2],
            name=graph_node.name,
        )

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': tf.strings.join,
                'tf_inputs': {
                    'inputs': [input_tensor_1, input_tensor_2],
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
@@ -31,50 +31,65 @@ class StringNormalizer(tf_keras.layers.Layer):
31
31
  self.case_change_action = case_change_action
32
32
  self.is_case_sensitive = is_case_sensitive
33
33
  self.locale = locale
34
- self.stopwords = set(stopwords)
34
+ self.stopwords = list(stopwords) if stopwords is not None else []
35
+
36
def _apply_case_action(self, inputs):
    """Re-case ``inputs`` according to ``self.case_change_action``.

    "LOWER" maps to ``tf.strings.lower`` and "UPPER" to ``tf.strings.upper``;
    any other value leaves ``inputs`` unchanged.
    """
    action = self.case_change_action
    if action == "UPPER":
        return tf.strings.upper(inputs)
    if action == "LOWER":
        return tf.strings.lower(inputs)
    return inputs
42
+
43
def _stopword_mask(self, inputs):
    """Return a bool tensor shaped like ``inputs`` that is True for non-stopwords.

    With no stopwords configured every element is kept. When
    ``is_case_sensitive`` is falsy, both the inputs and the stopwords are
    lower-cased before comparison.
    """
    if len(self.stopwords) == 0:
        return tf.ones_like(inputs, dtype=tf.bool)
    candidates = inputs
    stop_tensor = tf.constant(self.stopwords, dtype=tf.string)
    if not self.is_case_sensitive:
        candidates = tf.strings.lower(candidates)
        stop_tensor = tf.strings.lower(stop_tensor)
    # Compare every element against every stopword via a trailing axis.
    pairwise_equal = tf.equal(tf.expand_dims(candidates, axis=-1), stop_tensor)
    is_stopword = tf.reduce_any(pairwise_equal, axis=-1)
    return tf.logical_not(is_stopword)
35
60
 
36
61
  def call(self, inputs):
37
- if not self.is_case_sensitive:
38
- # if self.locale:
39
- # inputs = text.case_fold_utf8(inputs)
40
- # else:
41
- # inputs = tf.strings.lower(inputs)
42
- inputs = tf.strings.lower(inputs)
43
- elif self.case_change_action == "LOWER":
44
- inputs = tf.strings.lower(inputs)
45
- elif self.case_change_action == "UPPER":
46
- inputs = tf.strings.upper(inputs)
47
-
48
- # if self.tokenizer:
49
- # tokenized = self.tokenizer.tokenize(inputs)
50
- # else:
51
- # tokenized = tf.strings.split(inputs)
52
- tokenized = tf.strings.split(inputs)
62
+ def process_1d():
63
+ mask = self._stopword_mask(inputs)
64
+ filtered = tf.boolean_mask(inputs, mask)
65
+ filtered = self._apply_case_action(filtered)
66
+ return tf.cond(
67
+ tf.equal(tf.size(filtered), 0),
68
+ lambda: tf.constant([""], dtype=tf.string),
69
+ lambda: filtered,
70
+ )
53
71
 
54
- if not self.is_case_sensitive:
55
- # stopwords = [
56
- # tf.strings.lower(word) \
57
- # if self.locale is None else text.case_fold_utf8(word) \
58
- # for word in self.stopwords
59
- # ]
60
- stopwords = [
61
- tf.strings.lower(word) for word in self.stopwords
62
- ]
63
- else:
64
- stopwords = self.stopwords
65
-
66
- return \
67
- tf.ragged.boolean_mask(
68
- tokenized,
69
- ~tf.reduce_any(
70
- tf.equal(
71
- tf.expand_dims(tokenized, axis=-1),
72
- stopwords,
73
- ),
74
- axis=-1,
75
- )
72
+ def process_2d():
73
+ row = inputs[0]
74
+ mask = self._stopword_mask(row)
75
+ filtered = tf.boolean_mask(row, mask)
76
+ filtered = self._apply_case_action(filtered)
77
+ filtered = tf.expand_dims(filtered, axis=0)
78
+ return tf.cond(
79
+ tf.equal(tf.size(filtered), 0),
80
+ lambda: tf.constant([[""]], dtype=tf.string),
81
+ lambda: filtered,
76
82
  )
77
83
 
84
+ input_rank = inputs.shape.rank
85
+ if input_rank is None:
86
+ input_rank = tf.rank(inputs)
87
+ if isinstance(input_rank, int):
88
+ if input_rank == 1:
89
+ return process_1d()
90
+ return process_2d()
91
+ return tf.cond(tf.equal(input_rank, 1), process_1d, process_2d)
92
+
78
93
 
79
94
  @print_node_info
80
95
  @inverted_operation_enable_disable
@@ -0,0 +1,156 @@
1
+ import random
2
+ random.seed(0)
3
+ import numpy as np
4
+ np.random.seed(0)
5
+ import tensorflow as tf
6
+ import onnx_graphsurgeon as gs
7
+ from onnx2tf.utils.common_functions import (
8
+ get_constant_or_variable,
9
+ print_node_info,
10
+ inverted_operation_enable_disable,
11
+ make_tf_node_info,
12
+ get_replacement_parameter,
13
+ pre_process_transpose,
14
+ post_process_transpose,
15
+ )
16
+
17
+
18
+ def _as_tensor(value):
19
+ if isinstance(value, np.ndarray):
20
+ return tf.convert_to_tensor(value)
21
+ if isinstance(value, (np.generic, str, bytes)):
22
+ return tf.convert_to_tensor(value)
23
+ return value
24
+
25
+
26
+ def _normalize_delimiter(delimiter):
27
+ if delimiter is None:
28
+ return None
29
+ if isinstance(delimiter, bytes):
30
+ delimiter = delimiter.decode('utf-8')
31
+ if delimiter == "":
32
+ return None
33
+ return delimiter
34
+
35
+
36
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """StringSplit

    Splits every element of a string tensor on ``delimiter`` (whitespace when
    the attribute is unset or empty), honouring ``maxsplit`` (-1, meaning no
    limit, when unset). Output 0 holds the substrings, padded with "" along a
    new trailing axis. Output 1 holds the substring count of each input
    element as int64, with the input's shape. Inputs of any rank (including
    scalars) are supported.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1

    graph_node_input = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_output_1: gs.Variable = graph_node.outputs[0]
    graph_node_output_2: gs.Variable = graph_node.outputs[1]
    output_1_shape = graph_node_output_1.shape
    output_1_dtype = graph_node_output_1.dtype
    output_2_shape = graph_node_output_2.shape
    output_2_dtype = graph_node_output_2.dtype

    input_tensor = tf_layers_dict[graph_node_input.name]['tf_node'] \
        if isinstance(graph_node_input, gs.Variable) else graph_node_input

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output_1.name] = {
        'optype': graph_node.op,
        'shape': output_1_shape,
        'dtype': output_1_dtype,
    }
    tf_layers_dict[graph_node_output_2.name] = {
        'optype': graph_node.op,
        'shape': output_2_shape,
        'dtype': output_2_dtype,
    }

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )

    # Generation of TF OP
    input_tensor = _as_tensor(input_tensor)
    delimiter = _normalize_delimiter(graph_node.attrs.get('delimiter', None))
    maxsplit = graph_node.attrs.get('maxsplit', None)
    if maxsplit is None:
        maxsplit = -1

    # For inputs of rank >= 2, tf.strings.split produces nested ragged
    # dimensions, so row_lengths() would return outer row lengths rather than
    # per-element token counts. A scalar input yields a dense result with no
    # to_tensor(). Splitting a flattened 1-D view gives exactly one ragged row
    # per input element; afterwards both outputs are reshaped back.
    input_shape = tf.shape(input_tensor)
    flat_input = tf.reshape(input_tensor, [-1])
    split_rt = tf.strings.split(
        input=flat_input,
        sep=delimiter,
        maxsplit=maxsplit,
    )
    flat_strings = split_rt.to_tensor(default_value="")
    output_strings = tf.reshape(
        flat_strings,
        tf.concat([input_shape, tf.shape(flat_strings)[-1:]], axis=0),
    )
    output_counts = split_rt.row_lengths()
    output_counts = tf.reshape(output_counts, input_shape)
    output_counts = tf.cast(output_counts, tf.int64)

    tf_layers_dict[graph_node_output_1.name]['tf_node'] = output_strings
    tf_layers_dict[graph_node_output_2.name]['tf_node'] = output_counts

    # Post-process transpose
    tf_layers_dict[graph_node_output_1.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output_1.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )
    tf_layers_dict[graph_node_output_2.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output_2.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[1].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output_1.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': tf.strings.split,
                'tf_inputs': {
                    'input': input_tensor,
                    'sep': delimiter,
                    'maxsplit': maxsplit,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output_1.name]['tf_node'],
                },
            }
        )
    tf_layers_dict[graph_node_output_2.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': tf.strings.split,
                'tf_inputs': {
                    'input': input_tensor,
                    'sep': delimiter,
                    'maxsplit': maxsplit,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output_2.name]['tf_node'],
                },
            }
        )