onnx2tf 1.29.17__py3-none-any.whl → 1.29.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from onnx2tf.onnx2tf import convert, main
2
2
 
3
- __version__ = '1.29.17'
3
+ __version__ = '1.29.18'
@@ -0,0 +1,147 @@
1
+ import random
2
+ random.seed(0)
3
+ import numpy as np
4
+ np.random.seed(0)
5
+ import tensorflow as tf
6
+ import onnx_graphsurgeon as gs
7
+ import cv2
8
+ from onnx2tf.utils.common_functions import (
9
+ get_constant_or_variable,
10
+ print_node_info,
11
+ inverted_operation_enable_disable,
12
+ make_tf_node_info,
13
+ get_replacement_parameter,
14
+ pre_process_transpose,
15
+ post_process_transpose,
16
+ )
17
+ from onnx2tf.utils.logging import *
18
+
19
+
20
def _as_tensor(value):
    """Promote NumPy arrays and plain scalars to TensorFlow tensors.

    Parameters
    ----------
    value: Any
        Candidate value.

    Returns
    -------
    Any
        ``tf.convert_to_tensor(value)`` when ``value`` is a NumPy array,
        a NumPy scalar, or a Python int/float/bool/str/bytes;
        otherwise ``value`` itself, untouched.
    """
    convertible_types = (np.ndarray, np.generic, int, float, bool, str, bytes)
    return tf.convert_to_tensor(value) if isinstance(value, convertible_types) else value
26
+
27
+
28
def _decode_image_np(encoded_stream, pixel_format):
    """Decode an encoded image byte stream into an HWC uint8 array.

    Parameters
    ----------
    encoded_stream: np.ndarray or None
        Encoded image bytes. Non-uint8 arrays are cast to uint8 and
        non-1D arrays are flattened before decoding.

    pixel_format: str
        'RGB', 'BGR' or 'Grayscale'. Any other value is decoded as a
        3-channel image without colour conversion.

    Returns
    -------
    np.ndarray
        uint8 array of shape (H, W, C) with C == 1 for 'Grayscale' and
        C == 3 otherwise. When the input is None, empty, or cannot be
        decoded, an empty array of shape (0, 0, C) is returned.
    """
    # Every "nothing decoded" result carries the channel count of the
    # requested pixel format, so the output always has a consistent
    # last dimension regardless of which path produced it.
    channels = 1 if pixel_format == 'Grayscale' else 3
    empty_image = np.zeros((0, 0, channels), dtype=np.uint8)

    if encoded_stream is None:
        return empty_image
    if encoded_stream.dtype != np.uint8:
        encoded_stream = encoded_stream.astype(np.uint8)
    if encoded_stream.size == 0:
        return empty_image
    if encoded_stream.ndim != 1:
        encoded_stream = encoded_stream.reshape(-1)

    try:
        if pixel_format == 'Grayscale':
            flag = cv2.IMREAD_GRAYSCALE
        else:
            flag = cv2.IMREAD_COLOR
        decoded = cv2.imdecode(encoded_stream, flag)
        if decoded is None:
            # cv2.imdecode signals an undecodable buffer by returning None.
            return empty_image
        if pixel_format == 'RGB':
            decoded = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
        if pixel_format == 'Grayscale' and decoded.ndim == 2:
            # Add the trailing channel axis so the result is always rank 3.
            decoded = decoded[..., np.newaxis]
        return decoded.astype(np.uint8)
    except Exception:
        # Deliberate best-effort: any decoding failure yields an empty image.
        return empty_image
54
+
55
+
56
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """ImageDecoder

    Builds the TensorFlow counterpart of the node and registers it, with
    its debug info, in ``tf_layers_dict`` under the output tensor name.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    onnx_input = graph_node.inputs[0]
    before_op_output_shape_trans = \
        tf_layers_dict.get(onnx_input.name, {}).get('before_op_output_shape_trans', True)

    graph_node_input = get_constant_or_variable(
        onnx_input,
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]
    output_name = graph_node_output.name

    # Variables resolve to the already-built TF node, anything else is used as-is
    if isinstance(graph_node_input, gs.Variable):
        input_tensor = tf_layers_dict[graph_node_input.name]['tf_node']
    else:
        input_tensor = graph_node_input

    # Preserving Graph Structure (Dict)
    output_entry = {
        'optype': graph_node.op,
        'shape': graph_node_output.shape,
        'dtype': graph_node_output.dtype,
    }
    tf_layers_dict[output_name] = output_entry

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=onnx_input.name,
        **kwargs,
    )

    # Generation of TF OP
    input_tensor = _as_tensor(input_tensor)
    pixel_format = graph_node.attrs.get('pixel_format', 'RGB')
    supported_formats = ['RGB', 'BGR', 'Grayscale']
    if pixel_format not in supported_formats:
        # Report the problem, then continue with RGB
        error(
            f'ImageDecoder pixel_format={pixel_format} is not supported.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        pixel_format = 'RGB'

    def _decode(encoded):
        # Binds the resolved pixel format for the NumPy-side decoder
        return _decode_image_np(encoded, pixel_format)

    decoded = tf.numpy_function(
        func=_decode,
        inp=[input_tensor],
        Tout=tf.uint8,
        name=graph_node.name,
    )
    if pixel_format == 'Grayscale':
        channels = 1
    else:
        channels = 3
    output_entry['tf_node'] = tf.ensure_shape(decoded, [None, None, channels])

    # Post-process transpose
    output_entry['tf_node'] = post_process_transpose(
        value_before_transpose=output_entry['tf_node'],
        param_target='outputs',
        param_name=output_name,
        **kwargs,
    )

    # Generation of Debug Info
    node_info = {
        'tf_op_type': 'ImageDecoder',
        'tf_inputs': {
            'encoded_stream': input_tensor,
            'pixel_format': pixel_format,
        },
        'tf_outputs': {
            'output': output_entry['tf_node'],
        },
    }
    output_entry['tf_node_info'] = make_tf_node_info(node_info=node_info)
@@ -0,0 +1,237 @@
1
+ import sys
2
+ import random
3
+ random.seed(0)
4
+ import numpy as np
5
+ np.random.seed(0)
6
+ import tensorflow as tf
7
+ import onnx_graphsurgeon as gs
8
+ from onnx2tf.utils.common_functions import (
9
+ get_constant_or_variable,
10
+ convert_axis,
11
+ print_node_info,
12
+ inverted_operation_enable_disable,
13
+ make_tf_node_info,
14
+ get_replacement_parameter,
15
+ pre_process_transpose,
16
+ post_process_transpose,
17
+ )
18
+ from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
19
+ from onnx2tf.utils.logging import *
20
+
21
+
22
def _as_tensor(value):
    """Promote NumPy arrays and plain scalars to TensorFlow tensors.

    Parameters
    ----------
    value: Any
        Candidate value.

    Returns
    -------
    Any
        ``tf.convert_to_tensor(value)`` when ``value`` is a NumPy array,
        a NumPy scalar, or a Python int/float/bool/str/bytes;
        otherwise ``value`` itself, untouched.
    """
    convertible_types = (np.ndarray, np.generic, int, float, bool, str, bytes)
    return tf.convert_to_tensor(value) if isinstance(value, convertible_types) else value
28
+
29
+
30
def _move_class_to_last(tensor, class_axis):
    """Transpose ``tensor`` so that ``class_axis`` becomes the last axis.

    Parameters
    ----------
    tensor: tf.Tensor
        Tensor whose static rank is read from ``tensor.shape.rank``.

    class_axis: int
        Axis holding the classes. Negative values count from the end.

    Returns
    -------
    tuple
        ``(transposed_tensor, perm)`` when a transpose was applied, or
        ``(tensor, None)`` when the class axis is already last or the
        static rank is unknown.
    """
    rank = tensor.shape.rank
    if rank is None:
        # No static rank means no permutation can be built; pass through.
        return tensor, None
    if class_axis < 0:
        # Normalise so the comparison and the permutation below stay valid.
        class_axis += rank
    if class_axis == rank - 1:
        return tensor, None
    perm = [i for i in range(rank) if i != class_axis] + [class_axis]
    return tf.transpose(tensor, perm=perm), perm
40
+
41
+
42
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """NegativeLogLikelihoodLoss

    Builds the loss from the log-probability input, the integer target
    and the optional per-class weight, then registers the result and its
    debug info in ``tf_layers_dict`` under the output tensor name.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2
    if len(graph_node.inputs) >= 3:
        before_op_output_shape_trans_3 = \
            tf_layers_dict.get(graph_node.inputs[2].name, {}).get('before_op_output_shape_trans', True)
        before_op_output_shape_trans = \
            before_op_output_shape_trans \
            and before_op_output_shape_trans_3

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    # The third input (per-class weight) is optional
    graph_node_input_3 = None
    if len(graph_node.inputs) >= 3:
        graph_node_input_3 = get_constant_or_variable(
            graph_node.inputs[2],
            before_op_output_shape_trans,
        )

    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype
    output_tf_dtype = NUMPY_DTYPES_TO_TF_DTYPES[dtype] \
        if isinstance(dtype, np.dtype) else dtype

    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    target_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
    weight_tensor = None
    if graph_node_input_3 is not None:
        weight_tensor = tf_layers_dict[graph_node_input_3.name]['tf_node'] \
            if isinstance(graph_node_input_3, gs.Variable) else graph_node_input_3

    reduction = graph_node.attrs.get('reduction', 'mean')
    ignore_index = graph_node.attrs.get('ignore_index', None)

    # The class dimension is ONNX axis 1; translate it to the TF-side axis
    input_rank = len(input_tensor.shape)
    class_axis = convert_axis(
        axis=1,
        tensor_rank=input_rank,
        before_op_output_shape_trans=before_op_output_shape_trans,
    )

    # Preserving Graph Structure (Dict)
    output_entry = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
    }
    if reduction == 'none':
        # Only the unreduced output keeps per-element structure, so only
        # then is the layout flag of the first input carried over
        output_entry['nhwc'] = tf_layers_dict[graph_node_input_1.name]['nhwc'] \
            if isinstance(graph_node_input_1, gs.Variable) \
                and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
    tf_layers_dict[graph_node_output.name] = output_entry

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )
    target_tensor = pre_process_transpose(
        value_before_transpose=target_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[1].name,
        **kwargs,
    )
    if weight_tensor is not None:
        weight_tensor = pre_process_transpose(
            value_before_transpose=weight_tensor,
            param_target='inputs',
            param_name=graph_node.inputs[2].name,
            **kwargs,
        )

    # Generation of TF OP
    input_tensor = _as_tensor(input_tensor)
    target_tensor = _as_tensor(target_tensor)
    if weight_tensor is not None:
        weight_tensor = _as_tensor(weight_tensor)

    # Put the class axis last so a one-hot over the last axis selects
    # the log-probability of each target class
    log_prob = input_tensor
    log_prob, _ = _move_class_to_last(log_prob, class_axis)

    depth = log_prob.shape[-1]
    if depth is None:
        depth = tf.shape(log_prob)[-1]
    depth = tf.cast(depth, tf.int32)

    labels = tf.cast(target_tensor, tf.int32)
    if ignore_index is not None:
        ignore_index_val = tf.cast(ignore_index, labels.dtype)
        mask = tf.equal(labels, ignore_index_val)
        # Ignored positions are pointed at class 0 so that one_hot/gather
        # receive a valid index; their contribution is zeroed below
        labels_safe = tf.where(mask, tf.zeros_like(labels), labels)
    else:
        mask = None
        labels_safe = labels

    one_hot = tf.one_hot(
        indices=labels_safe,
        depth=depth,
        axis=-1,
        dtype=log_prob.dtype,
    )
    selected = tf.reduce_sum(log_prob * one_hot, axis=-1)
    loss = -selected

    weight_per_label = None
    if weight_tensor is not None:
        weight_per_label = tf.gather(weight_tensor, labels_safe)
        weight_per_label = tf.cast(weight_per_label, loss.dtype)
        if mask is not None:
            weight_per_label = tf.where(mask, tf.zeros_like(weight_per_label), weight_per_label)
        loss = loss * weight_per_label

    if mask is not None:
        loss = tf.where(mask, tf.zeros_like(loss), loss)

    if reduction == 'none':
        output_tensor = loss
    elif reduction == 'sum':
        output_tensor = tf.reduce_sum(loss)
    elif reduction == 'mean':
        if weight_per_label is not None:
            # Weighted mean: normalise by the sum of the applied weights
            denom = tf.reduce_sum(weight_per_label)
            output_tensor = tf.math.divide_no_nan(tf.reduce_sum(loss), denom)
        elif mask is not None:
            # Without weights, ignored targets act as weight 0 and all other
            # targets as weight 1, so the mean is taken over the non-ignored
            # elements only rather than over every element
            valid_count = tf.reduce_sum(tf.cast(tf.logical_not(mask), loss.dtype))
            output_tensor = tf.math.divide_no_nan(tf.reduce_sum(loss), valid_count)
        else:
            output_tensor = tf.reduce_mean(loss)
    else:
        error(
            f'NegativeLogLikelihoodLoss reduction={reduction} is not supported.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    if output_tf_dtype is not None and output_tensor.dtype != output_tf_dtype:
        output_tensor = tf.cast(output_tensor, output_tf_dtype)

    tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'NegativeLogLikelihoodLoss',
                'tf_inputs': {
                    'input': input_tensor,
                    'target': target_tensor,
                    'weight': weight_tensor,
                    'reduction': reduction,
                    'ignore_index': ignore_index,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
@@ -0,0 +1,175 @@
1
+ import sys
2
+ import random
3
+ random.seed(0)
4
+ import numpy as np
5
+ np.random.seed(0)
6
+ import tensorflow as tf
7
+ import onnx_graphsurgeon as gs
8
+ from onnx2tf.utils.common_functions import (
9
+ get_constant_or_variable,
10
+ convert_axis,
11
+ print_node_info,
12
+ inverted_operation_enable_disable,
13
+ make_tf_node_info,
14
+ get_replacement_parameter,
15
+ pre_process_transpose,
16
+ post_process_transpose,
17
+ )
18
+ from onnx2tf.utils.enums import NUMPY_DTYPES_TO_TF_DTYPES
19
+ from onnx2tf.utils.logging import *
20
+
21
+
22
def _as_tensor(value):
    """Promote NumPy arrays and plain scalars to TensorFlow tensors.

    Parameters
    ----------
    value: Any
        Candidate value.

    Returns
    -------
    Any
        ``tf.convert_to_tensor(value)`` when ``value`` is a NumPy array,
        a NumPy scalar, or a Python int/float/bool/str/bytes;
        otherwise ``value`` itself, untouched.
    """
    convertible_types = (np.ndarray, np.generic, int, float, bool, str, bytes)
    return tf.convert_to_tensor(value) if isinstance(value, convertible_types) else value
28
+
29
+
30
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """RMSNormalization

    Computes ``x / sqrt(mean(x^2, axes) + epsilon) * scale`` and registers
    the result and its debug info in ``tf_layers_dict`` under the output
    tensor name. Exits the process when the input rank cannot be determined.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    before_op_output_shape_trans_1 = \
        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans_2 = \
        tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
    before_op_output_shape_trans = \
        before_op_output_shape_trans_1 \
        and before_op_output_shape_trans_2

    graph_node_input_1 = get_constant_or_variable(
        graph_node.inputs[0],
        before_op_output_shape_trans,
    )
    graph_node_input_2 = get_constant_or_variable(
        graph_node.inputs[1],
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]
    shape = graph_node_output.shape
    dtype = graph_node_output.dtype
    output_tf_dtype = NUMPY_DTYPES_TO_TF_DTYPES[dtype] \
        if isinstance(dtype, np.dtype) else dtype

    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
    scale_tensor = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
        if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2

    # Preserving Graph Structure (Dict)
    tf_layers_dict[graph_node_output.name] = {
        'optype': graph_node.op,
        'shape': shape,
        'dtype': dtype,
        'nhwc': tf_layers_dict[graph_node_input_1.name]['nhwc'] \
            if isinstance(graph_node_input_1, gs.Variable) \
                and 'nhwc' in tf_layers_dict[graph_node_input_1.name].keys() else False
    }

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[0].name,
        **kwargs,
    )
    scale_tensor = pre_process_transpose(
        value_before_transpose=scale_tensor,
        param_target='inputs',
        param_name=graph_node.inputs[1].name,
        **kwargs,
    )

    # Generation of TF OP
    input_tensor = _as_tensor(input_tensor)
    scale_tensor = _as_tensor(scale_tensor)

    # Prefer the TF static rank, fall back to the ONNX shape annotation
    input_rank = input_tensor.shape.rank
    if input_rank is None and graph_node.inputs[0].shape is not None:
        input_rank = len(graph_node.inputs[0].shape)
    if input_rank is None:
        error(
            f'RMSNormalization requires known input rank.\n' +
            f'graph_node.name: {graph_node.name}'
        )
        sys.exit(1)

    onnx_axis = graph_node.attrs.get('axis', -1)
    if onnx_axis < 0:
        onnx_axis += input_rank
    axis = convert_axis(
        axis=onnx_axis,
        tensor_rank=input_rank,
        before_op_output_shape_trans=before_op_output_shape_trans,
    )
    epsilon = graph_node.attrs.get('epsilon', 1e-05)
    stash_type = int(graph_node.attrs.get('stash_type', 1))

    # The normalized axes are ONNX axes [axis, rank). Each one is converted
    # individually because a layout transpose does not keep that range
    # contiguous, so range(converted_axis, rank) would pick other axes.
    # NOTE(review): assumes convert_axis maps one ONNX axis to its TF-side
    # index and is the identity when no transpose is in effect - confirm.
    # NOTE(review): scale is still multiplied as-is; when the input layout
    # was transposed its broadcasting may need the same treatment - confirm.
    axes = sorted(
        convert_axis(
            axis=onnx_normalized_axis,
            tensor_rank=input_rank,
            before_op_output_shape_trans=before_op_output_shape_trans,
        )
        for onnx_normalized_axis in range(onnx_axis, input_rank)
    )

    # stash_type == 1 requests float32 for the intermediate computation
    compute_dtype = input_tensor.dtype
    if stash_type == 1 and input_tensor.dtype != tf.float32:
        compute_dtype = tf.float32

    x = tf.cast(input_tensor, compute_dtype)
    xsquared = tf.math.square(x)
    xsquared_mean = tf.reduce_mean(xsquared, axis=axes, keepdims=True)
    rms = tf.sqrt(xsquared_mean + tf.cast(epsilon, compute_dtype))
    normalized = x / rms

    # Return to the input dtype before applying the scale
    if compute_dtype != input_tensor.dtype:
        normalized = tf.cast(normalized, input_tensor.dtype)
    if scale_tensor.dtype != normalized.dtype:
        scale_tensor = tf.cast(scale_tensor, normalized.dtype)

    output_tensor = tf.math.multiply(
        normalized,
        scale_tensor,
        name=graph_node.name,
    )

    if output_tf_dtype is not None and output_tensor.dtype != output_tf_dtype:
        output_tensor = tf.cast(output_tensor, output_tf_dtype)

    tf_layers_dict[graph_node_output.name]['tf_node'] = output_tensor

    # Post-process transpose
    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
        param_target='outputs',
        param_name=graph_node.outputs[0].name,
        **kwargs,
    )

    # Generation of Debug Info
    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
        make_tf_node_info(
            node_info={
                'tf_op_type': 'RMSNormalization',
                'tf_inputs': {
                    'input': input_tensor,
                    'scale': scale_tensor,
                    'axis': axis,
                    'epsilon': epsilon,
                    'stash_type': stash_type,
                },
                'tf_outputs': {
                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
                },
            }
        )
@@ -0,0 +1,108 @@
1
+ import random
2
+ random.seed(0)
3
+ import numpy as np
4
+ np.random.seed(0)
5
+ import tensorflow as tf
6
+ import onnx_graphsurgeon as gs
7
+ from onnx2tf.utils.common_functions import (
8
+ get_constant_or_variable,
9
+ print_node_info,
10
+ inverted_operation_enable_disable,
11
+ make_tf_node_info,
12
+ get_replacement_parameter,
13
+ pre_process_transpose,
14
+ post_process_transpose,
15
+ )
16
+
17
+
18
def _as_tensor(value):
    """Promote NumPy arrays and plain scalars to TensorFlow tensors.

    Parameters
    ----------
    value: Any
        Candidate value.

    Returns
    -------
    Any
        ``tf.convert_to_tensor(value)`` when ``value`` is a NumPy array,
        a NumPy scalar, or a Python int/float/bool/str/bytes;
        otherwise ``value`` itself, untouched.
    """
    convertible_types = (np.ndarray, np.generic, int, float, bool, str, bytes)
    return tf.convert_to_tensor(value) if isinstance(value, convertible_types) else value
24
+
25
+
26
@print_node_info
@inverted_operation_enable_disable
@get_replacement_parameter
def make_node(
    *,
    graph_node: gs.Node,
    tf_layers_dict: dict,
    **kwargs: dict,
):
    """RegexFullMatch

    Applies ``tf.strings.regex_full_match`` with the node's ``pattern``
    attribute and registers the result and its debug info in
    ``tf_layers_dict`` under the output tensor name.

    Parameters
    ----------
    graph_node: gs.Node
        graph_surgeon Node

    tf_layers_dict: dict
        optype, shape, dtype, tensorflow graph
    """
    onnx_input = graph_node.inputs[0]
    before_op_output_shape_trans = \
        tf_layers_dict.get(onnx_input.name, {}).get('before_op_output_shape_trans', True)

    graph_node_input = get_constant_or_variable(
        onnx_input,
        before_op_output_shape_trans,
    )
    graph_node_output: gs.Variable = graph_node.outputs[0]
    output_name = graph_node_output.name

    # Variables resolve to the already-built TF node, anything else is used as-is
    if isinstance(graph_node_input, gs.Variable):
        input_tensor = tf_layers_dict[graph_node_input.name]['tf_node']
    else:
        input_tensor = graph_node_input

    # Preserving Graph Structure (Dict)
    output_entry = {
        'optype': graph_node.op,
        'shape': graph_node_output.shape,
        'dtype': graph_node_output.dtype,
    }
    tf_layers_dict[output_name] = output_entry

    # Pre-process transpose
    input_tensor = pre_process_transpose(
        value_before_transpose=input_tensor,
        param_target='inputs',
        param_name=onnx_input.name,
        **kwargs,
    )

    # Generation of TF OP
    input_tensor = _as_tensor(input_tensor)
    pattern = graph_node.attrs.get('pattern', '')

    matched = tf.strings.regex_full_match(
        input=input_tensor,
        pattern=pattern,
        name=graph_node.name,
    )
    output_entry['tf_node'] = matched

    # Post-process transpose
    output_entry['tf_node'] = post_process_transpose(
        value_before_transpose=output_entry['tf_node'],
        param_target='outputs',
        param_name=output_name,
        **kwargs,
    )

    # Generation of Debug Info
    node_info = {
        'tf_op_type': tf.strings.regex_full_match,
        'tf_inputs': {
            'input': input_tensor,
            'pattern': pattern,
        },
        'tf_outputs': {
            'output': output_entry['tf_node'],
        },
    }
    output_entry['tf_node_info'] = make_tf_node_info(node_info=node_info)