onnx2tf 1.28.1__py3-none-any.whl → 1.28.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnx2tf/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from onnx2tf.onnx2tf import convert, main
2
2
 
3
- __version__ = '1.28.1'
3
+ __version__ = '1.28.2'
onnx2tf/ops/MatMul.py CHANGED
@@ -25,6 +25,59 @@ from onnx2tf.utils.enums import (
25
25
  )
26
26
  from typing import Any, Dict, List
27
27
 
28
def _matmul_output_shape(shape_a, shape_b):
    """ _matmul_output_shape
    Computes the shape of matrix M=np.matmul(A, B) given shapes of A and B,
    without materializing or multiplying the operands.

    Follows the numpy.matmul shape rules:
      - (K,)        @ (K,)          -> ()
      - (K,)        @ (..., K, N)   -> (..., N)
      - (..., M, K) @ (K,)          -> (..., M)
      - (..., M, K) @ (..., K, N)   -> broadcast(batch_a, batch_b) + (M, N)

    Parameters
    ----------
    shape_a: sequence (tuple, list, ...) with the shape of matrix A.
    shape_b: sequence (tuple, list, ...) with the shape of matrix B.

    Returns
    -------
    Tuple with the shape of M=np.matmul(A, B), or None when the shape cannot
    be determined statically: an operand is missing or 0-d, a dimension is
    not a non-negative integer (None, -1, a symbolic name such as 'N'), or
    the shapes are incompatible for matmul. Callers can fall back to
    evaluating np.matmul directly when None is returned.
    """
    def _is_static(shape):
        # Only non-negative integers are usable; zero-size dims are legal for
        # np.matmul, while None / negative / symbolic (str) dims are unknown.
        return all(
            isinstance(dim, (int, np.integer)) and dim >= 0
            for dim in shape
        )

    if shape_a is None or shape_b is None:
        return None
    # Normalize to tuples so the tuple concatenations below work for any
    # sequence type (lists would otherwise raise TypeError).
    shape_a = tuple(shape_a)
    shape_b = tuple(shape_b)

    if len(shape_a) == 0 or len(shape_b) == 0 \
        or not _is_static(shape_a) or not _is_static(shape_b):
        # Scalars are not valid matmul operands, and unknown dims make the
        # output shape undeterminable.
        return None

    # Handle 1D cases as per numpy.matmul rules
    if len(shape_a) == 1 and len(shape_b) == 1:
        # Vector dot product -> scalar
        return () if shape_a[0] == shape_b[0] else None
    if len(shape_a) == 1:
        # (K,) @ (..., K, N) -> (..., N)
        if shape_a[0] != shape_b[-2]:
            return None  # Incompatible shapes for matmul
        return shape_b[:-2] + (shape_b[-1],)
    if len(shape_b) == 1:
        # (..., M, K) @ (K,) -> (..., M)
        if shape_a[-1] != shape_b[0]:
            return None  # Incompatible shapes for matmul
        return shape_a[:-2] + (shape_a[-2],)

    # (..., M, K) @ (..., K, N) -> broadcast(...), M, N
    if shape_a[-1] != shape_b[-2]:
        return None  # Incompatible shapes for matmul
    try:
        # broadcast_shapes left-pads the shorter batch shape with 1s itself,
        # so no manual rank alignment is needed.
        batch_shape = tuple(np.broadcast_shapes(shape_a[:-2], shape_b[:-2]))
    except ValueError:
        return None  # Batch dimensions cannot be broadcast together
    return batch_shape + (shape_a[-2], shape_b[-1])
28
81
 
29
82
  @print_node_info
30
83
  @inverted_operation_enable_disable
@@ -199,8 +252,13 @@ def make_node(
199
252
  and sum([1 if isinstance(s, str) else 0 for s in target_onnx_output_shape]) == 0:
200
253
  dummy_np_1 = np.ones(list(input_tensor_1.shape), dtype=np.float32).transpose(tensor_1_candidate_for_transposition)
201
254
  dummy_np_2 = np.ones(list(input_tensor_2.shape), dtype=np.float32).transpose(tensor_2_candidate_for_transposition)
202
- dummy_np_result: np.ndarray = np.matmul(dummy_np_1, dummy_np_2)
203
- if np.prod(dummy_np_result.shape) != np.prod(target_onnx_output_shape):
255
+
256
+ actual_output_shape = _matmul_output_shape(dummy_np_1.shape, dummy_np_2.shape)
257
+ if actual_output_shape is None:
258
+ dummy_np_result: np.ndarray = np.matmul(dummy_np_1, dummy_np_2)
259
+ actual_output_shape = dummy_np_result.shape
260
+
261
+ if np.prod(actual_output_shape) != np.prod(target_onnx_output_shape):
204
262
  continue
205
263
 
206
264
  # Build TF dummy model
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx2tf
3
- Version: 1.28.1
3
+ Version: 1.28.2
4
4
  Summary: Self-Created Tools to convert ONNX files (NCHW) to TensorFlow/TFLite/Keras format (NHWC). The purpose of this tool is to solve the massive Transpose extrapolation problem in onnx-tensorflow (onnx-tf).
5
5
  Home-page: https://github.com/PINTO0309/onnx2tf
6
6
  Author: Katsuya Hyodo
@@ -334,7 +334,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
334
334
  docker run --rm -it \
335
335
  -v `pwd`:/workdir \
336
336
  -w /workdir \
337
- ghcr.io/pinto0309/onnx2tf:1.28.1
337
+ ghcr.io/pinto0309/onnx2tf:1.28.2
338
338
 
339
339
  or
340
340
 
@@ -342,7 +342,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
342
342
  docker run --rm -it \
343
343
  -v `pwd`:/workdir \
344
344
  -w /workdir \
345
- docker.io/pinto0309/onnx2tf:1.28.1
345
+ docker.io/pinto0309/onnx2tf:1.28.2
346
346
 
347
347
  or
348
348
 
@@ -1,4 +1,4 @@
1
- onnx2tf/__init__.py,sha256=bDhHm0Li-GVAqSndFjjdNfRBkb8ErJ0Zq5yO8jupiuo,66
1
+ onnx2tf/__init__.py,sha256=Q-EJ5Kmj8v4pyqNo39_LSMw82TTyA_BlTLfHEgfsSiE,66
2
2
  onnx2tf/__main__.py,sha256=2RSCQ7d4lc6CwD-rlGn9UicPFg-P5du7ZD_yh-kuBEU,57
3
3
  onnx2tf/onnx2tf.py,sha256=ufjdjeokS96PyvhLAV4nOaiPZ69FCP2kbvBVGahzxxQ,146784
4
4
  onnx2tf/ops/Abs.py,sha256=V7btmCG_ZvK_qJovUsguq0ZMJ349mhNQ4FHSgzP_Yuo,4029
@@ -82,7 +82,7 @@ onnx2tf/ops/LessOrEqual.py,sha256=9Lc8qaYUPVC6yZoQluNqcdHnvpUbfWBOI4Ow38RRAJo,45
82
82
  onnx2tf/ops/Log.py,sha256=UZebF3SGq85BnoPgYyN2j-zzFRp67fJnYPNyu33W55o,3582
83
83
  onnx2tf/ops/LogSoftmax.py,sha256=j2nhYY7__8ViLFJVLA5tS98QEvGS1gTIW0QCdnZWUPQ,3923
84
84
  onnx2tf/ops/LpNormalization.py,sha256=Uu15HgxFNXb6gNMgdTJyf0SLPaLbcbkOYqY_4hMBxNA,3153
85
- onnx2tf/ops/MatMul.py,sha256=95HrWr3Dt6BLqx_zqm3WXBw_WzrWLObYVgz4K1yrhqE,19060
85
+ onnx2tf/ops/MatMul.py,sha256=KHhRyQCyxe6845f-AOI1UJzA3rGTssG6eyKmDw0oegs,21466
86
86
  onnx2tf/ops/MatMulInteger.py,sha256=qHqzdJNI9SeJDbW8pR90baYCdGN6FdOez4hi9EzwXoc,6538
87
87
  onnx2tf/ops/Max.py,sha256=w5nMciO_6ApYUobHuwMGuS3xhuza7eSvKDRhvMPgAuo,3256
88
88
  onnx2tf/ops/MaxPool.py,sha256=_JC4eqBTh-qLkZCMG8RZhthRZ8D2d821zaFMWeGMEWc,15775
@@ -190,10 +190,10 @@ onnx2tf/utils/enums.py,sha256=7c5TqetqB07VjyHoxJHfLgtqBqk9ZRyUF33fPOJR1IM,1649
190
190
  onnx2tf/utils/iterative_json_optimizer.py,sha256=qqeIxWGxrhcCYk8-ebWnblnOkzDCwi-nseipHzHR_bk,10436
191
191
  onnx2tf/utils/json_auto_generator.py,sha256=Vyy21SYEoSL0b-I1cUnaXR-CPoO8LJYQ3fAS2ulZSMM,61964
192
192
  onnx2tf/utils/logging.py,sha256=yUCmPuJ_XiUItM3sZMcaMO24JErkQy7zZwVTYWAuiKg,1982
193
- onnx2tf-1.28.1.dist-info/licenses/LICENSE,sha256=5v_Kxihy8i6mzHVl349ikSREaIdsl9YeUnX1KBDLD2w,1070
194
- onnx2tf-1.28.1.dist-info/licenses/LICENSE_onnx-tensorflow,sha256=gK4GtS9S5YcyINu6uuNNWdo-kBClyEM4MFLFGiNTeRM,11231
195
- onnx2tf-1.28.1.dist-info/METADATA,sha256=YJh7w6UDBuxyXgmBP_2fYu8lAjnKBc1arwPPJCT_MCc,151177
196
- onnx2tf-1.28.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
197
- onnx2tf-1.28.1.dist-info/entry_points.txt,sha256=gDPK8ToCFPKMvm8jr9xrGOkXtORJJVh4736fBEKO5k0,41
198
- onnx2tf-1.28.1.dist-info/top_level.txt,sha256=WgfPiEy3f6vZ_FOpAIEA2CF3TCx1eYrhGw93Ih6b9Fw,8
199
- onnx2tf-1.28.1.dist-info/RECORD,,
193
+ onnx2tf-1.28.2.dist-info/licenses/LICENSE,sha256=5v_Kxihy8i6mzHVl349ikSREaIdsl9YeUnX1KBDLD2w,1070
194
+ onnx2tf-1.28.2.dist-info/licenses/LICENSE_onnx-tensorflow,sha256=gK4GtS9S5YcyINu6uuNNWdo-kBClyEM4MFLFGiNTeRM,11231
195
+ onnx2tf-1.28.2.dist-info/METADATA,sha256=psCTn2IAkkXxmDaFe_1wLdA_lyCdP_ooOUvoiOluBfk,151177
196
+ onnx2tf-1.28.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
197
+ onnx2tf-1.28.2.dist-info/entry_points.txt,sha256=gDPK8ToCFPKMvm8jr9xrGOkXtORJJVh4736fBEKO5k0,41
198
+ onnx2tf-1.28.2.dist-info/top_level.txt,sha256=WgfPiEy3f6vZ_FOpAIEA2CF3TCx1eYrhGw93Ih6b9Fw,8
199
+ onnx2tf-1.28.2.dist-info/RECORD,,