onnx2tf 1.29.16__py3-none-any.whl → 1.29.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/ops/Add.py +112 -0
- onnx2tf/ops/Concat.py +169 -23
- onnx2tf/ops/ImageDecoder.py +147 -0
- onnx2tf/ops/NegativeLogLikelihoodLoss.py +237 -0
- onnx2tf/ops/RMSNormalization.py +175 -0
- onnx2tf/ops/RegexFullMatch.py +108 -0
- onnx2tf/ops/RotaryEmbedding.py +285 -0
- onnx2tf/ops/Scan.py +438 -0
- onnx2tf/ops/SoftmaxCrossEntropyLoss.py +289 -0
- onnx2tf/ops/StringConcat.py +128 -0
- onnx2tf/ops/StringNormalizer.py +54 -39
- onnx2tf/ops/StringSplit.py +156 -0
- onnx2tf/ops/TensorScatter.py +223 -0
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/METADATA +13 -12
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/RECORD +18 -8
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/WHEEL +1 -1
- {onnx2tf-1.29.16.dist-info → onnx2tf-1.29.18.dist-info}/entry_points.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/ops/Add.py
CHANGED
```diff
@@ -21,6 +21,7 @@ from onnx2tf.utils.common_functions import (
     disable_unnecessary_transpose,
     shape_unmatched_special_avoidance_workaround,
     merge_two_consecutive_identical_ops_into_one,
+    transpose_with_flexing_deterrence,
     deterring_shape_corruption_due_to_broadcast,
     acquisition_of_validation_data,
     onnx_tf_tensor_validation,
@@ -297,6 +298,117 @@ def make_node(
         )
         tf_type = tf.identity

+    def _normalize_dim(dim):
+        return int(dim) if isinstance(dim, (int, np.integer)) else None
+
+    def _get_static_shape(tensor):
+        shape = getattr(tensor, 'shape', None)
+        if shape is None or shape == tf.TensorShape(None):
+            return None
+        return [_normalize_dim(dim) for dim in list(shape)]
+
+    def _shape_match_with_none(expected, actual):
+        if expected is None or actual is None:
+            return False
+        if len(expected) != len(actual):
+            return False
+        for e_dim, a_dim in zip(expected, actual):
+            e_dim = _normalize_dim(e_dim)
+            a_dim = _normalize_dim(a_dim)
+            if e_dim is None or a_dim is None:
+                continue
+            if e_dim != a_dim:
+                return False
+        return True
+
+    def _perm_shape(shape, perm):
+        return [shape[i] for i in perm] if shape is not None else None
+
+    def _limited_perms(rank):
+        identity = list(range(rank))
+        perms = [identity]
+        if rank == 3:
+            perms.append([0, 2, 1])
+        elif rank == 4:
+            perms.extend([[0, 2, 3, 1], [0, 3, 1, 2]])
+        elif rank == 5:
+            perms.extend([[0, 2, 3, 4, 1], [0, 4, 1, 2, 3]])
+        return perms
+
+    def _ranked_perms(perms, input_shape, onnx_shape):
+        if input_shape is None or onnx_shape is None:
+            return perms
+        scored = []
+        for perm in perms:
+            score = 0
+            for out_idx, in_idx in enumerate(perm):
+                if out_idx >= len(onnx_shape) or in_idx >= len(input_shape):
+                    continue
+                o_dim = _normalize_dim(onnx_shape[out_idx])
+                i_dim = input_shape[in_idx]
+                if isinstance(o_dim, int) and isinstance(i_dim, int) and o_dim == i_dim:
+                    score += o_dim
+            scored.append((score, 1 if perm == list(range(len(perm))) else 0, perm))
+        scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
+        return [p for _, _, p in scored]
+
+    # Rescue guard for unexpected broadcasted shapes
+    if not enable_gelu:
+        expected_shape = None
+        if graph_node_output_shape is not None:
+            expected_shape = [_normalize_dim(dim) for dim in list(graph_node_output_shape)]
+        output_shape = _get_static_shape(tf_layers_dict[graph_node_output.name]['tf_node'])
+        input_shape_1 = _get_static_shape(input_tensor_1)
+        input_shape_2 = _get_static_shape(input_tensor_2)
+        if expected_shape is not None \
+            and output_shape is not None \
+            and not _shape_match_with_none(expected_shape, output_shape) \
+            and input_shape_1 is not None \
+            and input_shape_2 is not None \
+            and len(input_shape_1) == len(expected_shape) \
+            and len(input_shape_2) == len(expected_shape):
+
+            rank = len(expected_shape)
+            perms = _limited_perms(rank)
+            perm_list_1 = _ranked_perms(perms, input_shape_1, expected_shape)
+            perm_list_2 = _ranked_perms(perms, input_shape_2, expected_shape)
+            rescue_done = False
+            for perm_1 in perm_list_1:
+                for perm_2 in perm_list_2:
+                    try_input_1 = transpose_with_flexing_deterrence(
+                        input_tensor=input_tensor_1,
+                        perm=perm_1,
+                        **kwargs,
+                    )
+                    try_input_2 = transpose_with_flexing_deterrence(
+                        input_tensor=input_tensor_2,
+                        perm=perm_2,
+                        **kwargs,
+                    )
+                    try:
+                        rescue_tensor = tf.math.add(
+                            x=try_input_1 \
+                                if not isinstance(try_input_1, np.ndarray) \
+                                else tf.convert_to_tensor(try_input_1),
+                            y=try_input_2 \
+                                if not isinstance(try_input_2, np.ndarray) \
+                                else tf.convert_to_tensor(try_input_2),
+                            name=graph_node.name,
+                        )
+                    except Exception as ex:
+                        continue
+
+                    rescue_shape = _get_static_shape(rescue_tensor)
+                    if _shape_match_with_none(expected_shape, rescue_shape):
+                        input_tensor_1 = try_input_1
+                        input_tensor_2 = try_input_2
+                        tf_layers_dict[graph_node_output.name]['tf_node'] = rescue_tensor
+                        tf_type = tf.math.add
+                        rescue_done = True
+                        break
+                if rescue_done:
+                    break
+
     # Post-process transpose
     tf_layers_dict[graph_node_output.name]['tf_node'] = \
         post_process_transpose(
```
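The new rescue path in Add.py ranks a small set of layout permutations by how many statically known dimensions they would line up with the ONNX output shape, then re-tries the addition in that order. Below is a condensed standalone sketch of that ranking heuristic, adapted from the two helpers in the diff above; the shapes and the print are invented for illustration and are not onnx2tf code.

```python
# Condensed sketch of the permutation-ranking heuristic from the Add.py rescue
# path: permutations whose known dims match the ONNX output shape score higher.
def _limited_perms(rank):
    perms = [list(range(rank))]
    if rank == 3:
        perms.append([0, 2, 1])
    elif rank == 4:
        perms.extend([[0, 2, 3, 1], [0, 3, 1, 2]])
    elif rank == 5:
        perms.extend([[0, 2, 3, 4, 1], [0, 4, 1, 2, 3]])
    return perms

def _ranked_perms(perms, input_shape, onnx_shape):
    scored = []
    for perm in perms:
        # reward permutations whose static dims line up with the ONNX output shape
        score = sum(
            onnx_shape[out_idx]
            for out_idx, in_idx in enumerate(perm)
            if isinstance(onnx_shape[out_idx], int)
            and isinstance(input_shape[in_idx], int)
            and onnx_shape[out_idx] == input_shape[in_idx]
        )
        scored.append((score, perm == list(range(len(perm))), perm))
    scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
    return [p for _, _, p in scored]

# Example (invented shapes): TF tensor is NHWC, ONNX expects an NCHW output.
print(_ranked_perms(_limited_perms(4), [1, 64, 64, 32], [1, 32, 64, 64]))
# -> [[0, 3, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1]]
```

With an NHWC-shaped TF tensor and an NCHW ONNX output shape, the NHWC-to-NCHW permutation `[0, 3, 1, 2]` scores highest and is attempted first.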
onnx2tf/ops/Concat.py
CHANGED
```diff
@@ -291,6 +291,78 @@ def make_node(
         tf_type = tf.constant

     else:
+        def _normalize_dim(dim):
+            return int(dim) if isinstance(dim, (int, np.integer)) else None
+
+        def _get_static_shape(tensor):
+            shape = getattr(tensor, 'shape', None)
+            if shape is None or shape == tf.TensorShape(None):
+                return None
+            return [_normalize_dim(dim) for dim in list(shape)]
+
+        def _shape_match_with_none(onnx_shape, tf_shape):
+            if onnx_shape is None or tf_shape is None:
+                return False
+            if len(onnx_shape) != len(tf_shape):
+                return False
+            for o_dim, t_dim in zip(onnx_shape, tf_shape):
+                o_dim = _normalize_dim(o_dim)
+                t_dim = _normalize_dim(t_dim)
+                if o_dim is None or t_dim is None:
+                    continue
+                if o_dim != t_dim:
+                    return False
+            return True
+
+        def _can_concat_shapes(shapes, axis):
+            if shapes is None or any(s is None for s in shapes):
+                return True
+            rank = len(shapes[0])
+            for idx in range(rank):
+                if idx == axis:
+                    continue
+                dims = [s[idx] for s in shapes]
+                known = [d for d in dims if isinstance(d, int)]
+                if len(known) >= 2 and len(set(known)) != 1:
+                    return False
+            return True
+
+        def _perm_shape(shape, perm):
+            return [shape[i] for i in perm] if shape is not None else None
+
+        def _limited_perms(rank):
+            identity = list(range(rank))
+            perms = [identity]
+            if rank == 3:
+                perms.append([0, 2, 1])
+            elif rank == 4:
+                perms.extend([[0, 2, 3, 1], [0, 3, 1, 2]])
+            elif rank == 5:
+                perms.extend([[0, 2, 3, 4, 1], [0, 4, 1, 2, 3]])
+            return perms
+
+        def _base_perms(rank):
+            if rank <= 1:
+                return [list(range(rank))]
+            return [list(p) for p in itertools.permutations(range(rank))]
+
+        def _ranked_perms(perms, input_shape, axis, onnx_shape):
+            identity = list(range(len(perms[0]))) if perms else []
+            scored = []
+            for perm in perms:
+                score = 0
+                if input_shape is not None and onnx_shape is not None:
+                    for out_idx, in_idx in enumerate(perm):
+                        if out_idx == axis:
+                            continue
+                        o_dim = _normalize_dim(onnx_shape[out_idx]) if out_idx < len(onnx_shape) else None
+                        i_dim = input_shape[in_idx] if in_idx < len(input_shape) else None
+                        if isinstance(o_dim, int) and isinstance(i_dim, int) and o_dim == i_dim:
+                            score += o_dim
+                scored.append((score, 1 if perm == identity else 0, perm))
+            scored.sort(key=lambda x: (x[0], x[1]), reverse=True)
+            return [p for _, _, p in scored]
+
         try:
             # normal concat attempt
             tf_layers_dict[graph_node_output.name]['tf_node'] = \
@@ -301,35 +373,109 @@ def make_node(
                 )
         except:
             # Workaround to reduce error rate when merging tensors with undefined dimensions
-            … (15 removed lines not rendered in the source diff view)
+            original_values = values
+            original_shapes = [_get_static_shape(v) for v in original_values]
+            value_rank = getattr(original_values[0].shape, 'rank', None)
+            if value_rank is None:
+                value_rank = len(original_values[0].shape)
+
+            onnx_shape_list = None
+            if onnx_output_shape is not None:
+                onnx_shape_list = [_normalize_dim(dim) for dim in list(onnx_output_shape)]
+
+            onnx_axis = int(graph_node.attrs.get('axis', 0))
+            onnx_axis = onnx_axis + value_rank if onnx_axis < 0 else onnx_axis
+
+            def _axis_score(axis_idx):
+                if onnx_shape_list is not None and axis_idx < len(onnx_shape_list):
+                    onnx_dim = onnx_shape_list[axis_idx]
+                    if isinstance(onnx_dim, int):
+                        return onnx_dim
+                score = 0
+                for shape in original_shapes:
+                    if shape is None or axis_idx >= len(shape):
+                        continue
+                    dim = shape[axis_idx]
+                    if isinstance(dim, int):
+                        score += dim
+                return score
+
+            axis_candidates = list(range(value_rank))
+            axis_candidates.sort(
+                key=lambda a: (_axis_score(a), 1 if a == onnx_axis else 0),
+                reverse=True,
+            )
+
+            base_perms = _base_perms(value_rank)
+            max_combo = 20000
+            if len(base_perms) ** len(original_values) > max_combo:
+                base_perms = _limited_perms(value_rank)
+
+            succeed = False
+            matched = False
+            chosen_axis = None
+            chosen_values = None
+            chosen_tensor = None
+
+            for axis_idx in axis_candidates:
+                perm_lists = [
+                    _ranked_perms(base_perms, shape, axis_idx, onnx_shape_list)
+                    for shape in original_shapes
+                ]
+                for perm_combo in itertools.product(*perm_lists):
+                    permuted_shapes = [
+                        _perm_shape(shape, perm) for shape, perm in zip(original_shapes, perm_combo)
+                    ]
+                    if not _can_concat_shapes(permuted_shapes, axis_idx):
+                        continue
+                    try_values = [
+                        value if perm == list(range(value_rank)) else
+                        transpose_with_flexing_deterrence(
+                            input_tensor=value,
+                            perm=perm,
+                            **kwargs,
+                        )
+                        for value, perm in zip(original_values, perm_combo)
+                    ]
                     try:
-                        … (1 removed line not rendered in the source diff view)
+                        concat_tensor = \
                             tf.concat(
-                                values=…
-                                axis=…
+                                values=try_values,
+                                axis=axis_idx,
                                 name=graph_node.name,
                             )
-                        … (1 removed line not rendered in the source diff view)
+                    except:
+                        continue
+
+                    if not succeed:
                         succeed = True
+                        chosen_axis = axis_idx
+                        chosen_values = try_values
+                        chosen_tensor = concat_tensor
+
+                    if onnx_shape_list is None:
+                        matched = True
+                        chosen_axis = axis_idx
+                        chosen_values = try_values
+                        chosen_tensor = concat_tensor
                         break
-                    … (4 removed lines not rendered in the source diff view)
+
+                    output_shape = _get_static_shape(concat_tensor)
+                    if _shape_match_with_none(onnx_shape_list, output_shape):
+                        matched = True
+                        chosen_axis = axis_idx
+                        chosen_values = try_values
+                        chosen_tensor = concat_tensor
+                        break
+                if matched:
+                    break
+
+            if succeed:
+                tf_layers_dict[graph_node_output.name]['tf_node'] = chosen_tensor
+                axis = chosen_axis
+                values = chosen_values
+            else:
+                raise

         # Attempts to force axis correction when the number of axes in the combined tensor do not exactly match.
         # However, if more than 2 patterns of correct answers exist, give up the correction.
```
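When the normal tf.concat fails, the new fallback in Concat.py searches over candidate concat axes and per-input permutations, pruning any combination whose statically known off-axis dimensions disagree. Below is a minimal standalone sketch of that pruning check, copied from the `_can_concat_shapes` helper in the diff and exercised on invented shapes; the prints are illustration only, not onnx2tf code.

```python
# Sketch of the _can_concat_shapes pruning used by the new Concat fallback:
# a candidate (permutation, axis) pair is skipped when two inputs disagree on
# a statically known dimension outside the concat axis. Shapes are invented.
def _can_concat_shapes(shapes, axis):
    if shapes is None or any(s is None for s in shapes):
        return True  # unknown rank/shape: let tf.concat itself decide
    rank = len(shapes[0])
    for idx in range(rank):
        if idx == axis:
            continue
        known = [d for d in (s[idx] for s in shapes) if isinstance(d, int)]
        if len(known) >= 2 and len(set(known)) != 1:
            return False
    return True

print(_can_concat_shapes([[1, None, 64], [1, 32, 64]], axis=1))  # True: only the concat axis differs
print(_can_concat_shapes([[1, 64, 64], [1, 32, 48]], axis=1))    # False: off-axis 64 vs 48 mismatch
```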
onnx2tf/ops/ImageDecoder.py
ADDED
```diff
@@ -0,0 +1,147 @@
+import random
+random.seed(0)
+import numpy as np
+np.random.seed(0)
+import tensorflow as tf
+import onnx_graphsurgeon as gs
+import cv2
+from onnx2tf.utils.common_functions import (
+    get_constant_or_variable,
+    print_node_info,
+    inverted_operation_enable_disable,
+    make_tf_node_info,
+    get_replacement_parameter,
+    pre_process_transpose,
+    post_process_transpose,
+)
+from onnx2tf.utils.logging import *
+
+
+def _as_tensor(value):
+    if isinstance(value, np.ndarray):
+        return tf.convert_to_tensor(value)
+    if isinstance(value, (np.generic, int, float, bool, str, bytes)):
+        return tf.convert_to_tensor(value)
+    return value
+
+
+def _decode_image_np(encoded_stream, pixel_format):
+    if encoded_stream is None:
+        return np.zeros((0, 0, 0), dtype=np.uint8)
+    if encoded_stream.dtype != np.uint8:
+        encoded_stream = encoded_stream.astype(np.uint8)
+    if encoded_stream.size == 0:
+        return np.zeros((0, 0, 0), dtype=np.uint8)
+    if encoded_stream.ndim != 1:
+        encoded_stream = encoded_stream.reshape(-1)
+    try:
+        if pixel_format == 'Grayscale':
+            flag = cv2.IMREAD_GRAYSCALE
+        else:
+            flag = cv2.IMREAD_COLOR
+        decoded = cv2.imdecode(encoded_stream, flag)
+        if decoded is None:
+            raise ValueError('cv2.imdecode failed')
+        if pixel_format == 'RGB':
+            decoded = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
+        if pixel_format == 'Grayscale' and decoded.ndim == 2:
+            decoded = decoded[..., np.newaxis]
+        return decoded.astype(np.uint8)
+    except Exception:
+        if pixel_format == 'Grayscale':
+            return np.zeros((0, 0, 1), dtype=np.uint8)
+        return np.zeros((0, 0, 3), dtype=np.uint8)
+
+
+@print_node_info
+@inverted_operation_enable_disable
+@get_replacement_parameter
+def make_node(
+    *,
+    graph_node: gs.Node,
+    tf_layers_dict: dict,
+    **kwargs: dict,
+):
+    """ImageDecoder
+
+    Parameters
+    ----------
+    graph_node: gs.Node
+        graph_surgeon Node
+
+    tf_layers_dict: dict
+        optype, shape, dtype, tensorflow graph
+    """
+    before_op_output_shape_trans_1 = \
+        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans = \
+        before_op_output_shape_trans_1
+
+    graph_node_input = get_constant_or_variable(
+        graph_node.inputs[0],
+        before_op_output_shape_trans,
+    )
+    graph_node_output: gs.Variable = graph_node.outputs[0]
+    shape = graph_node_output.shape
+    dtype = graph_node_output.dtype
+
+    input_tensor = tf_layers_dict[graph_node_input.name]['tf_node'] \
+        if isinstance(graph_node_input, gs.Variable) else graph_node_input
+
+    # Preserving Graph Structure (Dict)
+    tf_layers_dict[graph_node_output.name] = {
+        'optype': graph_node.op,
+        'shape': shape,
+        'dtype': dtype,
+    }
+
+    # Pre-process transpose
+    input_tensor = pre_process_transpose(
+        value_before_transpose=input_tensor,
+        param_target='inputs',
+        param_name=graph_node.inputs[0].name,
+        **kwargs,
+    )
+
+    # Generation of TF OP
+    input_tensor = _as_tensor(input_tensor)
+    pixel_format = graph_node.attrs.get('pixel_format', 'RGB')
+    if pixel_format not in ['RGB', 'BGR', 'Grayscale']:
+        error(
+            f'ImageDecoder pixel_format={pixel_format} is not supported.\n' +
+            f'graph_node.name: {graph_node.name}'
+        )
+        pixel_format = 'RGB'
+
+    decoded = tf.numpy_function(
+        func=lambda x: _decode_image_np(x, pixel_format),
+        inp=[input_tensor],
+        Tout=tf.uint8,
+        name=graph_node.name,
+    )
+    channels = 1 if pixel_format == 'Grayscale' else 3
+    decoded = tf.ensure_shape(decoded, [None, None, channels])
+    tf_layers_dict[graph_node_output.name]['tf_node'] = decoded
+
+    # Post-process transpose
+    tf_layers_dict[graph_node_output.name]['tf_node'] = post_process_transpose(
+        value_before_transpose=tf_layers_dict[graph_node_output.name]['tf_node'],
+        param_target='outputs',
+        param_name=graph_node.outputs[0].name,
+        **kwargs,
+    )
+
+    # Generation of Debug Info
+    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+        make_tf_node_info(
+            node_info={
+                'tf_op_type': 'ImageDecoder',
+                'tf_inputs': {
+                    'encoded_stream': input_tensor,
+                    'pixel_format': pixel_format,
+                },
+                'tf_outputs': {
+                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                },
+            }
+        )
```