onnx2tf 1.29.21__py3-none-any.whl → 1.29.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from onnx2tf.onnx2tf import convert, main
2
2
 
3
- __version__ = '1.29.21'
3
+ __version__ = '1.29.22'
onnx2tf/ops/MaxPool.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import random
2
2
  random.seed(0)
3
3
  import numpy as np
4
+ import itertools
4
5
  np.random.seed(0)
5
6
  import tensorflow as tf
6
7
  import tf_keras
@@ -119,6 +120,59 @@ def make_node(
119
120
  **kwargs,
120
121
  )
121
122
 
123
+ # Guard: brute-force axis alignment between NCHW and NHWC when batch dim mismatches.
124
+ # Only trigger when shapes are fully known to avoid destabilizing existing behavior.
125
+ def _shape_matches(shape_a, shape_b):
126
+ if shape_a is None or shape_b is None or len(shape_a) != len(shape_b):
127
+ return False
128
+ for dim_a, dim_b in zip(shape_a, shape_b):
129
+ if dim_a is None or dim_b is None:
130
+ continue
131
+ if dim_a != dim_b:
132
+ return False
133
+ return True
134
+
135
+ def _best_perm_to_match(cur_shape, target_shape):
136
+ rank = len(cur_shape)
137
+ best_perm = None
138
+ best_cost = None
139
+ for perm in itertools.permutations(range(rank)):
140
+ permuted = [cur_shape[i] for i in perm]
141
+ if not _shape_matches(permuted, target_shape):
142
+ continue
143
+ cost = sum(abs(i - perm[i]) for i in range(rank))
144
+ if best_cost is None or cost < best_cost:
145
+ best_cost = cost
146
+ best_perm = perm
147
+ return best_perm
148
+
149
+ try:
150
+ current_shape = input_tensor.shape.as_list()
151
+ except Exception:
152
+ current_shape = None
153
+
154
+ if onnx_input_shape is not None and current_shape is not None:
155
+ onnx_shape = [
156
+ dim if isinstance(dim, int) else None for dim in onnx_input_shape
157
+ ]
158
+ cur_shape = [
159
+ dim if isinstance(dim, int) else None for dim in current_shape
160
+ ]
161
+ if len(onnx_shape) in (3, 4, 5) \
162
+ and len(cur_shape) == len(onnx_shape) \
163
+ and None not in onnx_shape \
164
+ and None not in cur_shape:
165
+ expected_shape = [onnx_shape[0]] + onnx_shape[2:] + [onnx_shape[1]]
166
+ if cur_shape[0] != onnx_shape[0] \
167
+ and not _shape_matches(cur_shape, expected_shape):
168
+ perm = _best_perm_to_match(cur_shape, expected_shape)
169
+ if perm is not None:
170
+ input_tensor = transpose_with_flexing_deterrence(
171
+ input_tensor=input_tensor,
172
+ perm=list(perm),
173
+ **kwargs,
174
+ )
175
+
122
176
  filter = None
123
177
 
124
178
  auto_pad = graph_node.attrs.get('auto_pad', 'NOTSET')
@@ -146,6 +146,48 @@ def make_node(
146
146
  axis=axis,
147
147
  indices=indices_tensor,
148
148
  )
149
+ indices_rank = None
150
+ if hasattr(indices_tensor, "shape") and indices_tensor.shape is not None:
151
+ try:
152
+ indices_rank = len(indices_tensor.shape)
153
+ except TypeError:
154
+ indices_rank = indices_tensor.shape.rank
155
+ updates_rank = updates_tensor_rank
156
+ broadcast_shape = None
157
+ pad_rank = 0
158
+
159
+ def _pad_and_broadcast(target_tensor, pad_rank, target_shape):
160
+ tensor = target_tensor
161
+ if isinstance(tensor, np.ndarray):
162
+ tensor = tf.convert_to_tensor(tensor)
163
+ if pad_rank <= 0:
164
+ return tf.broadcast_to(tensor, target_shape)
165
+ tensor_shape = tf.shape(tensor)
166
+ new_shape = tf.concat(
167
+ [tf.ones([pad_rank], dtype=tf.int32), tensor_shape],
168
+ axis=0,
169
+ )
170
+ tensor = tf.reshape(tensor, new_shape)
171
+ return tf.broadcast_to(tensor, target_shape)
172
+
173
+ updates_tensor_for_scatter = updates_tensor
174
+ if indices_rank is not None and updates_rank is not None and indices_rank != updates_rank:
175
+ if indices_rank > updates_rank:
176
+ broadcast_shape = tf.shape(indices_tensor)
177
+ pad_rank = indices_rank - updates_rank
178
+ updates_tensor_for_scatter = _pad_and_broadcast(
179
+ updates_tensor,
180
+ pad_rank,
181
+ broadcast_shape,
182
+ )
183
+ else:
184
+ broadcast_shape = tf.shape(updates_tensor)
185
+ pad_rank = updates_rank - indices_rank
186
+ indices_tensor = _pad_and_broadcast(
187
+ indices_tensor,
188
+ pad_rank,
189
+ broadcast_shape,
190
+ )
149
191
  sparsified_dense_idx_shape = updates_tensor_shape
150
192
 
151
193
  if None not in sparsified_dense_idx_shape:
@@ -160,6 +202,13 @@ def make_node(
160
202
  ]
161
203
 
162
204
  idx_tensors_per_axis = tf.meshgrid(*idx_tensors_per_axis, indexing='ij')
205
+ if indices_rank is not None \
206
+ and updates_rank is not None \
207
+ and indices_rank > updates_rank:
208
+ idx_tensors_per_axis = [
209
+ _pad_and_broadcast(idx_tensor, pad_rank, broadcast_shape)
210
+ for idx_tensor in idx_tensors_per_axis
211
+ ]
163
212
  idx_tensors_per_axis[axis] = indices_tensor
164
213
  dim_expanded_idx_tensors_per_axis = [
165
214
  tf.expand_dims(idx_tensor, axis=-1)
@@ -194,7 +243,7 @@ def make_node(
194
243
  )
195
244
 
196
245
  indices = tf.reshape(coordinate, [-1, input_tensor_rank])
197
- updates = tf.reshape(updates_tensor, [-1])
246
+ updates = tf.reshape(updates_tensor_for_scatter, [-1])
198
247
  output = tf.tensor_scatter_nd_update(
199
248
  tensor=input_tensor,
200
249
  indices=indices,
onnx2tf/ops/Unique.py CHANGED
@@ -14,6 +14,7 @@ from onnx2tf.utils.common_functions import (
14
14
  make_tf_node_info,
15
15
  get_replacement_parameter,
16
16
  pre_process_transpose,
17
+ convert_axis,
17
18
  )
18
19
  from onnx2tf.utils.logging import Color
19
20
 
@@ -69,8 +70,20 @@ def make_node(
69
70
  **kwargs,
70
71
  )
71
72
 
73
+ input_tensor_shape = input_tensor.shape
74
+ tensor_rank = len(input_tensor_shape) \
75
+ if input_tensor_shape != tf.TensorShape(None) else 1
76
+
72
77
  axis = graph_node.attrs.get('axis', None)
73
78
  sorted = graph_node.attrs.get('sorted', 1)
79
+ if axis is not None:
80
+ if isinstance(axis, np.ndarray) and axis.shape == ():
81
+ axis = int(axis)
82
+ axis = convert_axis(
83
+ axis=int(axis),
84
+ tensor_rank=tensor_rank,
85
+ before_op_output_shape_trans=before_op_output_shape_trans,
86
+ )
74
87
 
75
88
  # Preserving Graph Structure (Dict)
76
89
  for graph_node_output in graph_node_outputs:
@@ -101,17 +114,64 @@ def make_node(
101
114
 
102
115
  # tf unique returns unsorted tensor, need to sort if option is enabled
103
116
  if sorted:
104
- # TODO: implement sort
105
- error_msg = f'' + \
106
- Color.RED(f'WARNING:') + ' ' + \
107
- f'Sort option in Unique ops is not implemented yet.'
108
- print(error_msg)
109
- assert False, error_msg
110
-
111
- tf_layers_dict[graph_node_outputs[0].name]['tf_node'] = y
112
- tf_layers_dict[graph_node_outputs[1].name]['tf_node'] = indices
113
- tf_layers_dict[graph_node_outputs[2].name]['tf_node'] = inverse_indices
114
- tf_layers_dict[graph_node_outputs[3].name]['tf_node'] = count
117
+ # Sort unique outputs to match ONNX sorted behavior.
118
+ def _argsort_supported(dtype):
119
+ return dtype.is_floating or dtype.is_integer or dtype == tf.bool
120
+
121
+ y_rank = y.shape.rank
122
+ axis_ = axis
123
+ if axis_ is None:
124
+ axis_ = 0
125
+ if axis_ < 0 and y_rank is not None:
126
+ axis_ = axis_ + y_rank
127
+
128
+ def _lexsort_perm(flat_2d):
129
+ if not _argsort_supported(flat_2d.dtype):
130
+ return None
131
+ cols = flat_2d.shape[1]
132
+ if cols is None:
133
+ return None
134
+ order = tf.range(tf.shape(flat_2d)[0])
135
+ for col in reversed(range(cols)):
136
+ col_vals = tf.gather(flat_2d, order)[:, col]
137
+ if col_vals.dtype == tf.bool:
138
+ col_vals = tf.cast(col_vals, tf.int32)
139
+ order = tf.gather(order, tf.argsort(col_vals, stable=True))
140
+ return order
141
+
142
+ order = None
143
+ if y_rank is not None and y_rank == 1:
144
+ if _argsort_supported(y.dtype):
145
+ sort_vals = y
146
+ if sort_vals.dtype == tf.bool:
147
+ sort_vals = tf.cast(sort_vals, tf.int32)
148
+ order = tf.argsort(sort_vals, stable=True)
149
+ elif y_rank is not None and axis_ is not None and 0 <= axis_ < y_rank:
150
+ perm = [axis_] + [i for i in range(y_rank) if i != axis_]
151
+ y_t = tf.transpose(y, perm)
152
+ flat = tf.reshape(y_t, [tf.shape(y_t)[0], -1])
153
+ order = _lexsort_perm(flat)
154
+
155
+ if order is None:
156
+ warn_msg = f'' + \
157
+ Color.YELLOW(f'WARNING:') + ' ' + \
158
+ f'Unique sort fallback to unsorted due to dynamic shape or unsupported dtype.'
159
+ print(warn_msg)
160
+ else:
161
+ y = tf.gather(y, order, axis=axis_)
162
+ count = tf.gather(count, order)
163
+ indices = tf.gather(indices, order)
164
+ inv_order = tf.argsort(order)
165
+ inverse_indices = tf.gather(inv_order, inverse_indices)
166
+
167
+ if len(graph_node_outputs) >= 1:
168
+ tf_layers_dict[graph_node_outputs[0].name]['tf_node'] = y
169
+ if len(graph_node_outputs) >= 2:
170
+ tf_layers_dict[graph_node_outputs[1].name]['tf_node'] = indices
171
+ if len(graph_node_outputs) >= 3:
172
+ tf_layers_dict[graph_node_outputs[2].name]['tf_node'] = inverse_indices
173
+ if len(graph_node_outputs) >= 4:
174
+ tf_layers_dict[graph_node_outputs[3].name]['tf_node'] = count
115
175
 
116
176
  # Generation of Debug Info
117
177
  tf_outputs = {f"output{idx}": value for idx, value in enumerate([y, indices, inverse_indices, count])}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx2tf
3
- Version: 1.29.21
3
+ Version: 1.29.22
4
4
  Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
5
5
  Keywords: onnx,tensorflow,tflite,keras,deep-learning,machine-learning
6
6
  Author: Katsuya Hyodo
@@ -365,7 +365,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
365
365
  docker run --rm -it \
366
366
  -v `pwd`:/workdir \
367
367
  -w /workdir \
368
- ghcr.io/pinto0309/onnx2tf:1.29.21
368
+ ghcr.io/pinto0309/onnx2tf:1.29.22
369
369
 
370
370
  or
371
371
 
@@ -373,7 +373,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
373
373
  docker run --rm -it \
374
374
  -v `pwd`:/workdir \
375
375
  -w /workdir \
376
- docker.io/pinto0309/onnx2tf:1.29.21
376
+ docker.io/pinto0309/onnx2tf:1.29.22
377
377
 
378
378
  or
379
379
 
@@ -1,4 +1,4 @@
1
- onnx2tf/__init__.py,sha256=oN6Sb7PL3XQhFQGL8NsC07srojuozSUJPFJDOv1ST2k,67
1
+ onnx2tf/__init__.py,sha256=EUJwryuQmTqI1Huqr9hrm1fJwxQqmGxUjAFgGvuB950,67
2
2
  onnx2tf/__main__.py,sha256=2RSCQ7d4lc6CwD-rlGn9UicPFg-P5du7ZD_yh-kuBEU,57
3
3
  onnx2tf/onnx2tf.py,sha256=BC-BFMf8QUG7PtOvpwglhe1sc4FhTO8AMrdlxKUN5jc,208204
4
4
  onnx2tf/ops/Abs.py,sha256=V7btmCG_ZvK_qJovUsguq0ZMJ349mhNQ4FHSgzP_Yuo,4029
@@ -98,7 +98,7 @@ onnx2tf/ops/LpPool.py,sha256=96eI1FaDgW0M_USWBCHFedvtojHTLL28_lb3mcEV55A,10470
98
98
  onnx2tf/ops/MatMul.py,sha256=X9cQSD4BCogkDP6D4YZEZmOWnsceGL8ppN8E4kqyjB0,23926
99
99
  onnx2tf/ops/MatMulInteger.py,sha256=qHqzdJNI9SeJDbW8pR90baYCdGN6FdOez4hi9EzwXoc,6538
100
100
  onnx2tf/ops/Max.py,sha256=w5nMciO_6ApYUobHuwMGuS3xhuza7eSvKDRhvMPgAuo,3256
101
- onnx2tf/ops/MaxPool.py,sha256=_JC4eqBTh-qLkZCMG8RZhthRZ8D2d821zaFMWeGMEWc,15775
101
+ onnx2tf/ops/MaxPool.py,sha256=0zO5gNfcwAhXAVxIbURiXAlLQ56oJvtpcyLm4leZVsM,17948
102
102
  onnx2tf/ops/MaxRoiPool.py,sha256=RYZyjnINqJd6k7KLFJ-D9iHjA2vR-m7WvhrumD9cmDk,8490
103
103
  onnx2tf/ops/MaxUnpool.py,sha256=dGIEvC45rFuWoeG1S9j4sjEdEUqiWs_xdY5DZH6X7b4,5743
104
104
  onnx2tf/ops/Mean.py,sha256=xfTjKpQntJB8uXAkgDLS77oLXy2yQh1Hlz0K2GltMh0,3132
@@ -160,7 +160,7 @@ onnx2tf/ops/STFT.py,sha256=LDKN309_dBu4v9AYpz70uMJbNjRFiOte9O3wUL4bIJw,4463
160
160
  onnx2tf/ops/ScaleAndTranslate.py,sha256=VQDDhSs9TyMLQy0mF7n8pZ2TuvoKY-Lhlzd7Inf4UdI,11989
161
161
  onnx2tf/ops/Scan.py,sha256=hfN-DX6Gp-dG5158WMoHRrDWZAra3VSbsjsiphNqRIQ,16293
162
162
  onnx2tf/ops/Scatter.py,sha256=5_rTM60FPCq8unyNPDO-BZXcuz6w9Uyl2Xqx-zJTpgg,746
163
- onnx2tf/ops/ScatterElements.py,sha256=mp-TmswDTA9Nv0B3G3b-khOCPCKHnhCI97jDRofoEM0,8561
163
+ onnx2tf/ops/ScatterElements.py,sha256=vXURNSkorfm7iQ_HA5vY9nf6YIkrAdqmB7sdEtWnUCo,10452
164
164
  onnx2tf/ops/ScatterND.py,sha256=-mVbxXjQor2T6HVHSJy5e0FHQmEfaHknaKPuSc3Oz4o,11005
165
165
  onnx2tf/ops/Selu.py,sha256=CD0SqQlTTe0chO7lebkrdfDFSk6Cg9zLhvrKomsSH4Y,3799
166
166
  onnx2tf/ops/SequenceAt.py,sha256=jpjl9gVJFagtg223YY26I0pUUEgEFjJGvSZWwbo2-mQ,3278
@@ -199,7 +199,7 @@ onnx2tf/ops/Tile.py,sha256=xkprg6yTaykivcHFJ644opzVPctaeplu-Ed-OpS98Gg,12720
199
199
  onnx2tf/ops/TopK.py,sha256=f6OG-DcMWneXwSjIkmY935SPyOMD5tMteHnlQHoJwQo,6348
200
200
  onnx2tf/ops/Transpose.py,sha256=GwJFp7zVqodEsv5mGWviuFqeK93uVM7dbRQ1N8Ua1hg,9774
201
201
  onnx2tf/ops/Trilu.py,sha256=uz2TgdErpo9GDp9n4PCe0_koIpNLgBoPCjv3A6VBTl8,4789
202
- onnx2tf/ops/Unique.py,sha256=GUuOeTO9px22dHmlAn2SOmRHvBgSXo-SaPWm5rYUtPc,4084
202
+ onnx2tf/ops/Unique.py,sha256=Dms8Pi3uo8dyBFBddc-83_x4JSJ21pbaWhxzXzYotr4,6507
203
203
  onnx2tf/ops/Unsqueeze.py,sha256=UJun_DYfg7aQaHoeAvWlB85oRtDWq2lP7kvb0njcaC0,12219
204
204
  onnx2tf/ops/Upsample.py,sha256=SX3N_wZHD8G5Z0PLcPgX1ZCzOdct-uTzxKeMhhzeBOw,5304
205
205
  onnx2tf/ops/Where.py,sha256=MaCcY9g4mKZQqCgh4xtoylicP-xVu9f4boKiu_q9Ow8,7711
@@ -211,7 +211,7 @@ onnx2tf/utils/enums.py,sha256=7c5TqetqB07VjyHoxJHfLgtqBqk9ZRyUF33fPOJR1IM,1649
211
211
  onnx2tf/utils/iterative_json_optimizer.py,sha256=qqeIxWGxrhcCYk8-ebWnblnOkzDCwi-nseipHzHR_bk,10436
212
212
  onnx2tf/utils/json_auto_generator.py,sha256=OC-SfKtUg7zUxaXTAg6kT0ShzIc3ByjDa3FNp173DtA,60302
213
213
  onnx2tf/utils/logging.py,sha256=yUCmPuJ_XiUItM3sZMcaMO24JErkQy7zZwVTYWAuiKg,1982
214
- onnx2tf-1.29.21.dist-info/WHEEL,sha256=fAguSjoiATBe7TNBkJwOjyL1Tt4wwiaQGtNtjRPNMQA,80
215
- onnx2tf-1.29.21.dist-info/entry_points.txt,sha256=GuhvLu7ZlYECumbmoiFlKX0mFPtFi_Ti9L-E5yuQqKs,42
216
- onnx2tf-1.29.21.dist-info/METADATA,sha256=LKRmQIHHTw23h1BZd-KhyHQz46BWSK9ib2WaUvgyld8,156067
217
- onnx2tf-1.29.21.dist-info/RECORD,,
214
+ onnx2tf-1.29.22.dist-info/WHEEL,sha256=fAguSjoiATBe7TNBkJwOjyL1Tt4wwiaQGtNtjRPNMQA,80
215
+ onnx2tf-1.29.22.dist-info/entry_points.txt,sha256=GuhvLu7ZlYECumbmoiFlKX0mFPtFi_Ti9L-E5yuQqKs,42
216
+ onnx2tf-1.29.22.dist-info/METADATA,sha256=eWX7J2epMkaxU8YW1paH77k23Lw7K4i_7WsJC9f6-is,156067
217
+ onnx2tf-1.29.22.dist-info/RECORD,,