onnx2tf 1.29.21__py3-none-any.whl → 1.29.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx2tf/__init__.py +1 -1
- onnx2tf/ops/MaxPool.py +54 -0
- onnx2tf/ops/ScatterElements.py +50 -1
- onnx2tf/ops/Unique.py +71 -11
- {onnx2tf-1.29.21.dist-info → onnx2tf-1.29.22.dist-info}/METADATA +3 -3
- {onnx2tf-1.29.21.dist-info → onnx2tf-1.29.22.dist-info}/RECORD +8 -8
- {onnx2tf-1.29.21.dist-info → onnx2tf-1.29.22.dist-info}/WHEEL +0 -0
- {onnx2tf-1.29.21.dist-info → onnx2tf-1.29.22.dist-info}/entry_points.txt +0 -0
onnx2tf/__init__.py
CHANGED
onnx2tf/ops/MaxPool.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import random
|
|
2
2
|
random.seed(0)
|
|
3
3
|
import numpy as np
|
|
4
|
+
import itertools
|
|
4
5
|
np.random.seed(0)
|
|
5
6
|
import tensorflow as tf
|
|
6
7
|
import tf_keras
|
|
@@ -119,6 +120,59 @@ def make_node(
|
|
|
119
120
|
**kwargs,
|
|
120
121
|
)
|
|
121
122
|
|
|
123
|
+
# Guard: brute-force axis alignment between NCHW and NHWC when batch dim mismatches.
|
|
124
|
+
# Only trigger when shapes are fully known to avoid destabilizing existing behavior.
|
|
125
|
+
def _shape_matches(shape_a, shape_b):
|
|
126
|
+
if shape_a is None or shape_b is None or len(shape_a) != len(shape_b):
|
|
127
|
+
return False
|
|
128
|
+
for dim_a, dim_b in zip(shape_a, shape_b):
|
|
129
|
+
if dim_a is None or dim_b is None:
|
|
130
|
+
continue
|
|
131
|
+
if dim_a != dim_b:
|
|
132
|
+
return False
|
|
133
|
+
return True
|
|
134
|
+
|
|
135
|
+
def _best_perm_to_match(cur_shape, target_shape):
|
|
136
|
+
rank = len(cur_shape)
|
|
137
|
+
best_perm = None
|
|
138
|
+
best_cost = None
|
|
139
|
+
for perm in itertools.permutations(range(rank)):
|
|
140
|
+
permuted = [cur_shape[i] for i in perm]
|
|
141
|
+
if not _shape_matches(permuted, target_shape):
|
|
142
|
+
continue
|
|
143
|
+
cost = sum(abs(i - perm[i]) for i in range(rank))
|
|
144
|
+
if best_cost is None or cost < best_cost:
|
|
145
|
+
best_cost = cost
|
|
146
|
+
best_perm = perm
|
|
147
|
+
return best_perm
|
|
148
|
+
|
|
149
|
+
try:
|
|
150
|
+
current_shape = input_tensor.shape.as_list()
|
|
151
|
+
except Exception:
|
|
152
|
+
current_shape = None
|
|
153
|
+
|
|
154
|
+
if onnx_input_shape is not None and current_shape is not None:
|
|
155
|
+
onnx_shape = [
|
|
156
|
+
dim if isinstance(dim, int) else None for dim in onnx_input_shape
|
|
157
|
+
]
|
|
158
|
+
cur_shape = [
|
|
159
|
+
dim if isinstance(dim, int) else None for dim in current_shape
|
|
160
|
+
]
|
|
161
|
+
if len(onnx_shape) in (3, 4, 5) \
|
|
162
|
+
and len(cur_shape) == len(onnx_shape) \
|
|
163
|
+
and None not in onnx_shape \
|
|
164
|
+
and None not in cur_shape:
|
|
165
|
+
expected_shape = [onnx_shape[0]] + onnx_shape[2:] + [onnx_shape[1]]
|
|
166
|
+
if cur_shape[0] != onnx_shape[0] \
|
|
167
|
+
and not _shape_matches(cur_shape, expected_shape):
|
|
168
|
+
perm = _best_perm_to_match(cur_shape, expected_shape)
|
|
169
|
+
if perm is not None:
|
|
170
|
+
input_tensor = transpose_with_flexing_deterrence(
|
|
171
|
+
input_tensor=input_tensor,
|
|
172
|
+
perm=list(perm),
|
|
173
|
+
**kwargs,
|
|
174
|
+
)
|
|
175
|
+
|
|
122
176
|
filter = None
|
|
123
177
|
|
|
124
178
|
auto_pad = graph_node.attrs.get('auto_pad', 'NOTSET')
|
onnx2tf/ops/ScatterElements.py
CHANGED
|
@@ -146,6 +146,48 @@ def make_node(
|
|
|
146
146
|
axis=axis,
|
|
147
147
|
indices=indices_tensor,
|
|
148
148
|
)
|
|
149
|
+
indices_rank = None
|
|
150
|
+
if hasattr(indices_tensor, "shape") and indices_tensor.shape is not None:
|
|
151
|
+
try:
|
|
152
|
+
indices_rank = len(indices_tensor.shape)
|
|
153
|
+
except TypeError:
|
|
154
|
+
indices_rank = indices_tensor.shape.rank
|
|
155
|
+
updates_rank = updates_tensor_rank
|
|
156
|
+
broadcast_shape = None
|
|
157
|
+
pad_rank = 0
|
|
158
|
+
|
|
159
|
+
def _pad_and_broadcast(target_tensor, pad_rank, target_shape):
|
|
160
|
+
tensor = target_tensor
|
|
161
|
+
if isinstance(tensor, np.ndarray):
|
|
162
|
+
tensor = tf.convert_to_tensor(tensor)
|
|
163
|
+
if pad_rank <= 0:
|
|
164
|
+
return tf.broadcast_to(tensor, target_shape)
|
|
165
|
+
tensor_shape = tf.shape(tensor)
|
|
166
|
+
new_shape = tf.concat(
|
|
167
|
+
[tf.ones([pad_rank], dtype=tf.int32), tensor_shape],
|
|
168
|
+
axis=0,
|
|
169
|
+
)
|
|
170
|
+
tensor = tf.reshape(tensor, new_shape)
|
|
171
|
+
return tf.broadcast_to(tensor, target_shape)
|
|
172
|
+
|
|
173
|
+
updates_tensor_for_scatter = updates_tensor
|
|
174
|
+
if indices_rank is not None and updates_rank is not None and indices_rank != updates_rank:
|
|
175
|
+
if indices_rank > updates_rank:
|
|
176
|
+
broadcast_shape = tf.shape(indices_tensor)
|
|
177
|
+
pad_rank = indices_rank - updates_rank
|
|
178
|
+
updates_tensor_for_scatter = _pad_and_broadcast(
|
|
179
|
+
updates_tensor,
|
|
180
|
+
pad_rank,
|
|
181
|
+
broadcast_shape,
|
|
182
|
+
)
|
|
183
|
+
else:
|
|
184
|
+
broadcast_shape = tf.shape(updates_tensor)
|
|
185
|
+
pad_rank = updates_rank - indices_rank
|
|
186
|
+
indices_tensor = _pad_and_broadcast(
|
|
187
|
+
indices_tensor,
|
|
188
|
+
pad_rank,
|
|
189
|
+
broadcast_shape,
|
|
190
|
+
)
|
|
149
191
|
sparsified_dense_idx_shape = updates_tensor_shape
|
|
150
192
|
|
|
151
193
|
if None not in sparsified_dense_idx_shape:
|
|
@@ -160,6 +202,13 @@ def make_node(
|
|
|
160
202
|
]
|
|
161
203
|
|
|
162
204
|
idx_tensors_per_axis = tf.meshgrid(*idx_tensors_per_axis, indexing='ij')
|
|
205
|
+
if indices_rank is not None \
|
|
206
|
+
and updates_rank is not None \
|
|
207
|
+
and indices_rank > updates_rank:
|
|
208
|
+
idx_tensors_per_axis = [
|
|
209
|
+
_pad_and_broadcast(idx_tensor, pad_rank, broadcast_shape)
|
|
210
|
+
for idx_tensor in idx_tensors_per_axis
|
|
211
|
+
]
|
|
163
212
|
idx_tensors_per_axis[axis] = indices_tensor
|
|
164
213
|
dim_expanded_idx_tensors_per_axis = [
|
|
165
214
|
tf.expand_dims(idx_tensor, axis=-1)
|
|
@@ -194,7 +243,7 @@ def make_node(
|
|
|
194
243
|
)
|
|
195
244
|
|
|
196
245
|
indices = tf.reshape(coordinate, [-1, input_tensor_rank])
|
|
197
|
-
updates = tf.reshape(
|
|
246
|
+
updates = tf.reshape(updates_tensor_for_scatter, [-1])
|
|
198
247
|
output = tf.tensor_scatter_nd_update(
|
|
199
248
|
tensor=input_tensor,
|
|
200
249
|
indices=indices,
|
onnx2tf/ops/Unique.py
CHANGED
|
@@ -14,6 +14,7 @@ from onnx2tf.utils.common_functions import (
|
|
|
14
14
|
make_tf_node_info,
|
|
15
15
|
get_replacement_parameter,
|
|
16
16
|
pre_process_transpose,
|
|
17
|
+
convert_axis,
|
|
17
18
|
)
|
|
18
19
|
from onnx2tf.utils.logging import Color
|
|
19
20
|
|
|
@@ -69,8 +70,20 @@ def make_node(
|
|
|
69
70
|
**kwargs,
|
|
70
71
|
)
|
|
71
72
|
|
|
73
|
+
input_tensor_shape = input_tensor.shape
|
|
74
|
+
tensor_rank = len(input_tensor_shape) \
|
|
75
|
+
if input_tensor_shape != tf.TensorShape(None) else 1
|
|
76
|
+
|
|
72
77
|
axis = graph_node.attrs.get('axis', None)
|
|
73
78
|
sorted = graph_node.attrs.get('sorted', 1)
|
|
79
|
+
if axis is not None:
|
|
80
|
+
if isinstance(axis, np.ndarray) and axis.shape == ():
|
|
81
|
+
axis = int(axis)
|
|
82
|
+
axis = convert_axis(
|
|
83
|
+
axis=int(axis),
|
|
84
|
+
tensor_rank=tensor_rank,
|
|
85
|
+
before_op_output_shape_trans=before_op_output_shape_trans,
|
|
86
|
+
)
|
|
74
87
|
|
|
75
88
|
# Preserving Graph Structure (Dict)
|
|
76
89
|
for graph_node_output in graph_node_outputs:
|
|
@@ -101,17 +114,64 @@ def make_node(
|
|
|
101
114
|
|
|
102
115
|
# tf unique returns unsorted tensor, need to sort if option is enabled
|
|
103
116
|
if sorted:
|
|
104
|
-
#
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
117
|
+
# Sort unique outputs to match ONNX sorted behavior.
|
|
118
|
+
def _argsort_supported(dtype):
|
|
119
|
+
return dtype.is_floating or dtype.is_integer or dtype == tf.bool
|
|
120
|
+
|
|
121
|
+
y_rank = y.shape.rank
|
|
122
|
+
axis_ = axis
|
|
123
|
+
if axis_ is None:
|
|
124
|
+
axis_ = 0
|
|
125
|
+
if axis_ < 0 and y_rank is not None:
|
|
126
|
+
axis_ = axis_ + y_rank
|
|
127
|
+
|
|
128
|
+
def _lexsort_perm(flat_2d):
|
|
129
|
+
if not _argsort_supported(flat_2d.dtype):
|
|
130
|
+
return None
|
|
131
|
+
cols = flat_2d.shape[1]
|
|
132
|
+
if cols is None:
|
|
133
|
+
return None
|
|
134
|
+
order = tf.range(tf.shape(flat_2d)[0])
|
|
135
|
+
for col in reversed(range(cols)):
|
|
136
|
+
col_vals = tf.gather(flat_2d, order)[:, col]
|
|
137
|
+
if col_vals.dtype == tf.bool:
|
|
138
|
+
col_vals = tf.cast(col_vals, tf.int32)
|
|
139
|
+
order = tf.gather(order, tf.argsort(col_vals, stable=True))
|
|
140
|
+
return order
|
|
141
|
+
|
|
142
|
+
order = None
|
|
143
|
+
if y_rank is not None and y_rank == 1:
|
|
144
|
+
if _argsort_supported(y.dtype):
|
|
145
|
+
sort_vals = y
|
|
146
|
+
if sort_vals.dtype == tf.bool:
|
|
147
|
+
sort_vals = tf.cast(sort_vals, tf.int32)
|
|
148
|
+
order = tf.argsort(sort_vals, stable=True)
|
|
149
|
+
elif y_rank is not None and axis_ is not None and 0 <= axis_ < y_rank:
|
|
150
|
+
perm = [axis_] + [i for i in range(y_rank) if i != axis_]
|
|
151
|
+
y_t = tf.transpose(y, perm)
|
|
152
|
+
flat = tf.reshape(y_t, [tf.shape(y_t)[0], -1])
|
|
153
|
+
order = _lexsort_perm(flat)
|
|
154
|
+
|
|
155
|
+
if order is None:
|
|
156
|
+
warn_msg = f'' + \
|
|
157
|
+
Color.YELLOW(f'WARNING:') + ' ' + \
|
|
158
|
+
f'Unique sort fallback to unsorted due to dynamic shape or unsupported dtype.'
|
|
159
|
+
print(warn_msg)
|
|
160
|
+
else:
|
|
161
|
+
y = tf.gather(y, order, axis=axis_)
|
|
162
|
+
count = tf.gather(count, order)
|
|
163
|
+
indices = tf.gather(indices, order)
|
|
164
|
+
inv_order = tf.argsort(order)
|
|
165
|
+
inverse_indices = tf.gather(inv_order, inverse_indices)
|
|
166
|
+
|
|
167
|
+
if len(graph_node_outputs) >= 1:
|
|
168
|
+
tf_layers_dict[graph_node_outputs[0].name]['tf_node'] = y
|
|
169
|
+
if len(graph_node_outputs) >= 2:
|
|
170
|
+
tf_layers_dict[graph_node_outputs[1].name]['tf_node'] = indices
|
|
171
|
+
if len(graph_node_outputs) >= 3:
|
|
172
|
+
tf_layers_dict[graph_node_outputs[2].name]['tf_node'] = inverse_indices
|
|
173
|
+
if len(graph_node_outputs) >= 4:
|
|
174
|
+
tf_layers_dict[graph_node_outputs[3].name]['tf_node'] = count
|
|
115
175
|
|
|
116
176
|
# Generation of Debug Info
|
|
117
177
|
tf_outputs = {f"output{idx}": value for idx, value in enumerate([y, indices, inverse_indices, count])}
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx2tf
|
|
3
|
-
Version: 1.29.
|
|
3
|
+
Version: 1.29.22
|
|
4
4
|
Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
|
|
5
5
|
Keywords: onnx,tensorflow,tflite,keras,deep-learning,machine-learning
|
|
6
6
|
Author: Katsuya Hyodo
|
|
@@ -365,7 +365,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
|
|
|
365
365
|
docker run --rm -it \
|
|
366
366
|
-v `pwd`:/workdir \
|
|
367
367
|
-w /workdir \
|
|
368
|
-
ghcr.io/pinto0309/onnx2tf:1.29.
|
|
368
|
+
ghcr.io/pinto0309/onnx2tf:1.29.22
|
|
369
369
|
|
|
370
370
|
or
|
|
371
371
|
|
|
@@ -373,7 +373,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
|
|
|
373
373
|
docker run --rm -it \
|
|
374
374
|
-v `pwd`:/workdir \
|
|
375
375
|
-w /workdir \
|
|
376
|
-
docker.io/pinto0309/onnx2tf:1.29.
|
|
376
|
+
docker.io/pinto0309/onnx2tf:1.29.22
|
|
377
377
|
|
|
378
378
|
or
|
|
379
379
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
onnx2tf/__init__.py,sha256=
|
|
1
|
+
onnx2tf/__init__.py,sha256=EUJwryuQmTqI1Huqr9hrm1fJwxQqmGxUjAFgGvuB950,67
|
|
2
2
|
onnx2tf/__main__.py,sha256=2RSCQ7d4lc6CwD-rlGn9UicPFg-P5du7ZD_yh-kuBEU,57
|
|
3
3
|
onnx2tf/onnx2tf.py,sha256=BC-BFMf8QUG7PtOvpwglhe1sc4FhTO8AMrdlxKUN5jc,208204
|
|
4
4
|
onnx2tf/ops/Abs.py,sha256=V7btmCG_ZvK_qJovUsguq0ZMJ349mhNQ4FHSgzP_Yuo,4029
|
|
@@ -98,7 +98,7 @@ onnx2tf/ops/LpPool.py,sha256=96eI1FaDgW0M_USWBCHFedvtojHTLL28_lb3mcEV55A,10470
|
|
|
98
98
|
onnx2tf/ops/MatMul.py,sha256=X9cQSD4BCogkDP6D4YZEZmOWnsceGL8ppN8E4kqyjB0,23926
|
|
99
99
|
onnx2tf/ops/MatMulInteger.py,sha256=qHqzdJNI9SeJDbW8pR90baYCdGN6FdOez4hi9EzwXoc,6538
|
|
100
100
|
onnx2tf/ops/Max.py,sha256=w5nMciO_6ApYUobHuwMGuS3xhuza7eSvKDRhvMPgAuo,3256
|
|
101
|
-
onnx2tf/ops/MaxPool.py,sha256=
|
|
101
|
+
onnx2tf/ops/MaxPool.py,sha256=0zO5gNfcwAhXAVxIbURiXAlLQ56oJvtpcyLm4leZVsM,17948
|
|
102
102
|
onnx2tf/ops/MaxRoiPool.py,sha256=RYZyjnINqJd6k7KLFJ-D9iHjA2vR-m7WvhrumD9cmDk,8490
|
|
103
103
|
onnx2tf/ops/MaxUnpool.py,sha256=dGIEvC45rFuWoeG1S9j4sjEdEUqiWs_xdY5DZH6X7b4,5743
|
|
104
104
|
onnx2tf/ops/Mean.py,sha256=xfTjKpQntJB8uXAkgDLS77oLXy2yQh1Hlz0K2GltMh0,3132
|
|
@@ -160,7 +160,7 @@ onnx2tf/ops/STFT.py,sha256=LDKN309_dBu4v9AYpz70uMJbNjRFiOte9O3wUL4bIJw,4463
|
|
|
160
160
|
onnx2tf/ops/ScaleAndTranslate.py,sha256=VQDDhSs9TyMLQy0mF7n8pZ2TuvoKY-Lhlzd7Inf4UdI,11989
|
|
161
161
|
onnx2tf/ops/Scan.py,sha256=hfN-DX6Gp-dG5158WMoHRrDWZAra3VSbsjsiphNqRIQ,16293
|
|
162
162
|
onnx2tf/ops/Scatter.py,sha256=5_rTM60FPCq8unyNPDO-BZXcuz6w9Uyl2Xqx-zJTpgg,746
|
|
163
|
-
onnx2tf/ops/ScatterElements.py,sha256=
|
|
163
|
+
onnx2tf/ops/ScatterElements.py,sha256=vXURNSkorfm7iQ_HA5vY9nf6YIkrAdqmB7sdEtWnUCo,10452
|
|
164
164
|
onnx2tf/ops/ScatterND.py,sha256=-mVbxXjQor2T6HVHSJy5e0FHQmEfaHknaKPuSc3Oz4o,11005
|
|
165
165
|
onnx2tf/ops/Selu.py,sha256=CD0SqQlTTe0chO7lebkrdfDFSk6Cg9zLhvrKomsSH4Y,3799
|
|
166
166
|
onnx2tf/ops/SequenceAt.py,sha256=jpjl9gVJFagtg223YY26I0pUUEgEFjJGvSZWwbo2-mQ,3278
|
|
@@ -199,7 +199,7 @@ onnx2tf/ops/Tile.py,sha256=xkprg6yTaykivcHFJ644opzVPctaeplu-Ed-OpS98Gg,12720
|
|
|
199
199
|
onnx2tf/ops/TopK.py,sha256=f6OG-DcMWneXwSjIkmY935SPyOMD5tMteHnlQHoJwQo,6348
|
|
200
200
|
onnx2tf/ops/Transpose.py,sha256=GwJFp7zVqodEsv5mGWviuFqeK93uVM7dbRQ1N8Ua1hg,9774
|
|
201
201
|
onnx2tf/ops/Trilu.py,sha256=uz2TgdErpo9GDp9n4PCe0_koIpNLgBoPCjv3A6VBTl8,4789
|
|
202
|
-
onnx2tf/ops/Unique.py,sha256=
|
|
202
|
+
onnx2tf/ops/Unique.py,sha256=Dms8Pi3uo8dyBFBddc-83_x4JSJ21pbaWhxzXzYotr4,6507
|
|
203
203
|
onnx2tf/ops/Unsqueeze.py,sha256=UJun_DYfg7aQaHoeAvWlB85oRtDWq2lP7kvb0njcaC0,12219
|
|
204
204
|
onnx2tf/ops/Upsample.py,sha256=SX3N_wZHD8G5Z0PLcPgX1ZCzOdct-uTzxKeMhhzeBOw,5304
|
|
205
205
|
onnx2tf/ops/Where.py,sha256=MaCcY9g4mKZQqCgh4xtoylicP-xVu9f4boKiu_q9Ow8,7711
|
|
@@ -211,7 +211,7 @@ onnx2tf/utils/enums.py,sha256=7c5TqetqB07VjyHoxJHfLgtqBqk9ZRyUF33fPOJR1IM,1649
|
|
|
211
211
|
onnx2tf/utils/iterative_json_optimizer.py,sha256=qqeIxWGxrhcCYk8-ebWnblnOkzDCwi-nseipHzHR_bk,10436
|
|
212
212
|
onnx2tf/utils/json_auto_generator.py,sha256=OC-SfKtUg7zUxaXTAg6kT0ShzIc3ByjDa3FNp173DtA,60302
|
|
213
213
|
onnx2tf/utils/logging.py,sha256=yUCmPuJ_XiUItM3sZMcaMO24JErkQy7zZwVTYWAuiKg,1982
|
|
214
|
-
onnx2tf-1.29.
|
|
215
|
-
onnx2tf-1.29.
|
|
216
|
-
onnx2tf-1.29.
|
|
217
|
-
onnx2tf-1.29.
|
|
214
|
+
onnx2tf-1.29.22.dist-info/WHEEL,sha256=fAguSjoiATBe7TNBkJwOjyL1Tt4wwiaQGtNtjRPNMQA,80
|
|
215
|
+
onnx2tf-1.29.22.dist-info/entry_points.txt,sha256=GuhvLu7ZlYECumbmoiFlKX0mFPtFi_Ti9L-E5yuQqKs,42
|
|
216
|
+
onnx2tf-1.29.22.dist-info/METADATA,sha256=eWX7J2epMkaxU8YW1paH77k23Lw7K4i_7WsJC9f6-is,156067
|
|
217
|
+
onnx2tf-1.29.22.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|