onnx 1.16.1-cp38-cp38-win32.whl → 1.17.0-cp38-cp38-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx might be problematic; consult the advisory on the package registry for more details.

Files changed (843)
  1. onnx/__init__.py +3 -1
  2. onnx/_custom_element_types.py +63 -0
  3. onnx/backend/base.py +17 -15
  4. onnx/backend/sample/ops/__init__.py +4 -4
  5. onnx/backend/sample/ops/abs.py +1 -0
  6. onnx/backend/test/__init__.py +1 -0
  7. onnx/backend/test/case/__init__.py +2 -2
  8. onnx/backend/test/case/base.py +6 -5
  9. onnx/backend/test/case/model/__init__.py +4 -3
  10. onnx/backend/test/case/model/expand.py +1 -0
  11. onnx/backend/test/case/model/gradient.py +1 -0
  12. onnx/backend/test/case/model/sequence.py +3 -1
  13. onnx/backend/test/case/model/shrink.py +1 -0
  14. onnx/backend/test/case/model/sign.py +1 -0
  15. onnx/backend/test/case/model/single-relu.py +1 -0
  16. onnx/backend/test/case/model/stringnormalizer.py +1 -1
  17. onnx/backend/test/case/node/__init__.py +31 -22
  18. onnx/backend/test/case/node/_image_decoder_data.py +1 -0
  19. onnx/backend/test/case/node/abs.py +1 -0
  20. onnx/backend/test/case/node/acos.py +1 -0
  21. onnx/backend/test/case/node/acosh.py +1 -0
  22. onnx/backend/test/case/node/adagrad.py +2 -1
  23. onnx/backend/test/case/node/adam.py +4 -1
  24. onnx/backend/test/case/node/add.py +1 -0
  25. onnx/backend/test/case/node/affinegrid.py +1 -0
  26. onnx/backend/test/case/node/ai_onnx_ml/array_feature_extractor.py +1 -0
  27. onnx/backend/test/case/node/ai_onnx_ml/binarizer.py +1 -0
  28. onnx/backend/test/case/node/ai_onnx_ml/label_encoder.py +1 -0
  29. onnx/backend/test/case/node/ai_onnx_ml/tree_ensemble.py +1 -0
  30. onnx/backend/test/case/node/and.py +1 -0
  31. onnx/backend/test/case/node/argmax.py +1 -0
  32. onnx/backend/test/case/node/argmin.py +1 -0
  33. onnx/backend/test/case/node/asin.py +1 -0
  34. onnx/backend/test/case/node/asinh.py +1 -0
  35. onnx/backend/test/case/node/atan.py +1 -0
  36. onnx/backend/test/case/node/atanh.py +1 -0
  37. onnx/backend/test/case/node/averagepool.py +1 -0
  38. onnx/backend/test/case/node/batchnorm.py +1 -0
  39. onnx/backend/test/case/node/bernoulli.py +1 -0
  40. onnx/backend/test/case/node/bitshift.py +1 -0
  41. onnx/backend/test/case/node/bitwiseand.py +1 -0
  42. onnx/backend/test/case/node/bitwisenot.py +1 -0
  43. onnx/backend/test/case/node/bitwiseor.py +1 -0
  44. onnx/backend/test/case/node/bitwisexor.py +1 -0
  45. onnx/backend/test/case/node/blackmanwindow.py +13 -3
  46. onnx/backend/test/case/node/cast.py +2 -1
  47. onnx/backend/test/case/node/castlike.py +1 -0
  48. onnx/backend/test/case/node/ceil.py +1 -0
  49. onnx/backend/test/case/node/celu.py +1 -0
  50. onnx/backend/test/case/node/center_crop_pad.py +1 -0
  51. onnx/backend/test/case/node/clip.py +1 -0
  52. onnx/backend/test/case/node/col2im.py +1 -1
  53. onnx/backend/test/case/node/compress.py +1 -0
  54. onnx/backend/test/case/node/concat.py +3 -2
  55. onnx/backend/test/case/node/constant.py +1 -0
  56. onnx/backend/test/case/node/constantofshape.py +1 -0
  57. onnx/backend/test/case/node/conv.py +1 -0
  58. onnx/backend/test/case/node/convinteger.py +1 -0
  59. onnx/backend/test/case/node/convtranspose.py +135 -0
  60. onnx/backend/test/case/node/cos.py +1 -0
  61. onnx/backend/test/case/node/cosh.py +1 -0
  62. onnx/backend/test/case/node/cumsum.py +1 -0
  63. onnx/backend/test/case/node/deformconv.py +17 -26
  64. onnx/backend/test/case/node/depthtospace.py +1 -0
  65. onnx/backend/test/case/node/dequantizelinear.py +1 -0
  66. onnx/backend/test/case/node/det.py +1 -0
  67. onnx/backend/test/case/node/dft.py +1 -0
  68. onnx/backend/test/case/node/div.py +1 -0
  69. onnx/backend/test/case/node/dropout.py +1 -0
  70. onnx/backend/test/case/node/dynamicquantizelinear.py +1 -0
  71. onnx/backend/test/case/node/einsum.py +2 -3
  72. onnx/backend/test/case/node/elu.py +1 -0
  73. onnx/backend/test/case/node/equal.py +1 -0
  74. onnx/backend/test/case/node/erf.py +1 -0
  75. onnx/backend/test/case/node/exp.py +1 -0
  76. onnx/backend/test/case/node/expand.py +1 -0
  77. onnx/backend/test/case/node/eyelike.py +1 -0
  78. onnx/backend/test/case/node/flatten.py +1 -0
  79. onnx/backend/test/case/node/floor.py +1 -0
  80. onnx/backend/test/case/node/gather.py +1 -0
  81. onnx/backend/test/case/node/gatherelements.py +1 -0
  82. onnx/backend/test/case/node/gathernd.py +1 -0
  83. onnx/backend/test/case/node/gelu.py +1 -0
  84. onnx/backend/test/case/node/gemm.py +3 -4
  85. onnx/backend/test/case/node/globalaveragepool.py +1 -0
  86. onnx/backend/test/case/node/globalmaxpool.py +1 -0
  87. onnx/backend/test/case/node/greater.py +1 -0
  88. onnx/backend/test/case/node/greater_equal.py +1 -0
  89. onnx/backend/test/case/node/gridsample.py +1 -0
  90. onnx/backend/test/case/node/groupnormalization.py +1 -0
  91. onnx/backend/test/case/node/gru.py +3 -2
  92. onnx/backend/test/case/node/hammingwindow.py +13 -2
  93. onnx/backend/test/case/node/hannwindow.py +10 -2
  94. onnx/backend/test/case/node/hardmax.py +1 -0
  95. onnx/backend/test/case/node/hardsigmoid.py +1 -0
  96. onnx/backend/test/case/node/hardswish.py +1 -0
  97. onnx/backend/test/case/node/identity.py +1 -0
  98. onnx/backend/test/case/node/if.py +1 -0
  99. onnx/backend/test/case/node/instancenorm.py +1 -0
  100. onnx/backend/test/case/node/isinf.py +1 -0
  101. onnx/backend/test/case/node/isnan.py +1 -0
  102. onnx/backend/test/case/node/layernormalization.py +1 -0
  103. onnx/backend/test/case/node/leakyrelu.py +1 -0
  104. onnx/backend/test/case/node/less.py +1 -0
  105. onnx/backend/test/case/node/less_equal.py +1 -0
  106. onnx/backend/test/case/node/log.py +1 -0
  107. onnx/backend/test/case/node/logsoftmax.py +1 -0
  108. onnx/backend/test/case/node/loop.py +4 -3
  109. onnx/backend/test/case/node/lppool.py +1 -0
  110. onnx/backend/test/case/node/lrn.py +1 -0
  111. onnx/backend/test/case/node/lstm.py +3 -2
  112. onnx/backend/test/case/node/matmul.py +1 -0
  113. onnx/backend/test/case/node/matmulinteger.py +1 -0
  114. onnx/backend/test/case/node/max.py +1 -0
  115. onnx/backend/test/case/node/maxpool.py +1 -0
  116. onnx/backend/test/case/node/maxunpool.py +1 -0
  117. onnx/backend/test/case/node/mean.py +1 -0
  118. onnx/backend/test/case/node/meanvariancenormalization.py +1 -0
  119. onnx/backend/test/case/node/melweightmatrix.py +1 -0
  120. onnx/backend/test/case/node/min.py +1 -0
  121. onnx/backend/test/case/node/mish.py +1 -0
  122. onnx/backend/test/case/node/mod.py +1 -0
  123. onnx/backend/test/case/node/momentum.py +1 -0
  124. onnx/backend/test/case/node/mul.py +1 -0
  125. onnx/backend/test/case/node/neg.py +1 -0
  126. onnx/backend/test/case/node/negativeloglikelihoodloss.py +4 -1
  127. onnx/backend/test/case/node/nonmaxsuppression.py +1 -0
  128. onnx/backend/test/case/node/nonzero.py +1 -0
  129. onnx/backend/test/case/node/not.py +1 -0
  130. onnx/backend/test/case/node/onehot.py +1 -0
  131. onnx/backend/test/case/node/optionalgetelement.py +3 -2
  132. onnx/backend/test/case/node/optionalhaselement.py +2 -3
  133. onnx/backend/test/case/node/or.py +1 -0
  134. onnx/backend/test/case/node/pad.py +2 -1
  135. onnx/backend/test/case/node/pow.py +1 -0
  136. onnx/backend/test/case/node/prelu.py +1 -0
  137. onnx/backend/test/case/node/qlinearconv.py +1 -0
  138. onnx/backend/test/case/node/qlinearmatmul.py +1 -0
  139. onnx/backend/test/case/node/quantizelinear.py +1 -0
  140. onnx/backend/test/case/node/rangeop.py +1 -0
  141. onnx/backend/test/case/node/reciprocal.py +1 -0
  142. onnx/backend/test/case/node/reduce_log_sum.py +1 -0
  143. onnx/backend/test/case/node/reduce_log_sum_exp.py +1 -0
  144. onnx/backend/test/case/node/reducel1.py +1 -0
  145. onnx/backend/test/case/node/reducel2.py +1 -0
  146. onnx/backend/test/case/node/reducemax.py +2 -1
  147. onnx/backend/test/case/node/reducemean.py +1 -0
  148. onnx/backend/test/case/node/reducemin.py +1 -0
  149. onnx/backend/test/case/node/reduceprod.py +1 -0
  150. onnx/backend/test/case/node/reducesum.py +2 -1
  151. onnx/backend/test/case/node/reducesumsquare.py +1 -0
  152. onnx/backend/test/case/node/regex_full_match.py +1 -0
  153. onnx/backend/test/case/node/relu.py +1 -0
  154. onnx/backend/test/case/node/reshape.py +1 -0
  155. onnx/backend/test/case/node/resize.py +3 -2
  156. onnx/backend/test/case/node/reversesequence.py +1 -0
  157. onnx/backend/test/case/node/rnn.py +3 -2
  158. onnx/backend/test/case/node/roialign.py +1 -0
  159. onnx/backend/test/case/node/round.py +4 -3
  160. onnx/backend/test/case/node/scan.py +1 -0
  161. onnx/backend/test/case/node/scatter.py +1 -0
  162. onnx/backend/test/case/node/scatterelements.py +7 -3
  163. onnx/backend/test/case/node/scatternd.py +1 -0
  164. onnx/backend/test/case/node/selu.py +1 -0
  165. onnx/backend/test/case/node/sequence_map.py +1 -0
  166. onnx/backend/test/case/node/sequenceinsert.py +4 -3
  167. onnx/backend/test/case/node/shape.py +1 -0
  168. onnx/backend/test/case/node/shrink.py +1 -0
  169. onnx/backend/test/case/node/sigmoid.py +1 -0
  170. onnx/backend/test/case/node/sign.py +1 -0
  171. onnx/backend/test/case/node/sin.py +1 -0
  172. onnx/backend/test/case/node/sinh.py +1 -0
  173. onnx/backend/test/case/node/size.py +1 -0
  174. onnx/backend/test/case/node/slice.py +1 -0
  175. onnx/backend/test/case/node/softmax.py +1 -0
  176. onnx/backend/test/case/node/softmaxcrossentropy.py +4 -1
  177. onnx/backend/test/case/node/softplus.py +1 -0
  178. onnx/backend/test/case/node/softsign.py +1 -0
  179. onnx/backend/test/case/node/spacetodepth.py +1 -0
  180. onnx/backend/test/case/node/split.py +1 -0
  181. onnx/backend/test/case/node/splittosequence.py +1 -0
  182. onnx/backend/test/case/node/sqrt.py +1 -0
  183. onnx/backend/test/case/node/squeeze.py +1 -0
  184. onnx/backend/test/case/node/stft.py +4 -1
  185. onnx/backend/test/case/node/string_concat.py +1 -0
  186. onnx/backend/test/case/node/string_split.py +1 -0
  187. onnx/backend/test/case/node/stringnormalizer.py +1 -0
  188. onnx/backend/test/case/node/sub.py +1 -0
  189. onnx/backend/test/case/node/sum.py +1 -0
  190. onnx/backend/test/case/node/tan.py +1 -0
  191. onnx/backend/test/case/node/tanh.py +1 -0
  192. onnx/backend/test/case/node/tfidfvectorizer.py +1 -0
  193. onnx/backend/test/case/node/thresholdedrelu.py +1 -0
  194. onnx/backend/test/case/node/tile.py +1 -0
  195. onnx/backend/test/case/node/topk.py +1 -0
  196. onnx/backend/test/case/node/transpose.py +1 -0
  197. onnx/backend/test/case/node/trilu.py +1 -0
  198. onnx/backend/test/case/node/unique.py +7 -0
  199. onnx/backend/test/case/node/unsqueeze.py +1 -0
  200. onnx/backend/test/case/node/upsample.py +1 -0
  201. onnx/backend/test/case/node/where.py +1 -0
  202. onnx/backend/test/case/node/xor.py +1 -0
  203. onnx/backend/test/case/test_case.py +6 -5
  204. onnx/backend/test/case/utils.py +2 -2
  205. onnx/backend/test/cmd_tools.py +1 -0
  206. onnx/backend/test/data/node/test_acos/model.onnx +0 -0
  207. onnx/backend/test/data/node/test_acos/test_data_set_0/output_0.pb +0 -0
  208. onnx/backend/test/data/node/test_acos_example/model.onnx +0 -0
  209. onnx/backend/test/data/node/test_acosh/model.onnx +0 -0
  210. onnx/backend/test/data/node/test_acosh/test_data_set_0/output_0.pb +1 -1
  211. onnx/backend/test/data/node/test_acosh_example/model.onnx +0 -0
  212. onnx/backend/test/data/node/test_asin/model.onnx +0 -0
  213. onnx/backend/test/data/node/test_asin/test_data_set_0/output_0.pb +1 -1
  214. onnx/backend/test/data/node/test_asin_example/model.onnx +0 -0
  215. onnx/backend/test/data/node/test_asinh/model.onnx +0 -0
  216. onnx/backend/test/data/node/test_asinh/test_data_set_0/output_0.pb +1 -1
  217. onnx/backend/test/data/node/test_asinh_example/model.onnx +0 -0
  218. onnx/backend/test/data/node/test_atan/model.onnx +0 -0
  219. onnx/backend/test/data/node/test_atan/test_data_set_0/output_0.pb +1 -1
  220. onnx/backend/test/data/node/test_atan_example/model.onnx +0 -0
  221. onnx/backend/test/data/node/test_atanh/model.onnx +0 -0
  222. onnx/backend/test/data/node/test_atanh/test_data_set_0/output_0.pb +2 -2
  223. onnx/backend/test/data/node/test_atanh_example/model.onnx +0 -0
  224. onnx/backend/test/data/node/test_averagepool_1d_default/model.onnx +0 -0
  225. onnx/backend/test/data/node/test_averagepool_2d_ceil/model.onnx +0 -0
  226. onnx/backend/test/data/node/test_averagepool_2d_default/model.onnx +0 -0
  227. onnx/backend/test/data/node/test_averagepool_2d_dilations/model.onnx +0 -0
  228. onnx/backend/test/data/node/test_averagepool_2d_pads/model.onnx +0 -0
  229. onnx/backend/test/data/node/test_averagepool_2d_pads_count_include_pad/model.onnx +0 -0
  230. onnx/backend/test/data/node/test_averagepool_2d_precomputed_pads/model.onnx +0 -0
  231. onnx/backend/test/data/node/test_averagepool_2d_precomputed_pads_count_include_pad/model.onnx +0 -0
  232. onnx/backend/test/data/node/test_averagepool_2d_precomputed_same_upper/model.onnx +0 -0
  233. onnx/backend/test/data/node/test_averagepool_2d_precomputed_strides/model.onnx +0 -0
  234. onnx/backend/test/data/node/test_averagepool_2d_same_lower/model.onnx +0 -0
  235. onnx/backend/test/data/node/test_averagepool_2d_same_upper/model.onnx +0 -0
  236. onnx/backend/test/data/node/test_averagepool_2d_strides/model.onnx +0 -0
  237. onnx/backend/test/data/node/test_averagepool_3d_default/model.onnx +0 -0
  238. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_0_ceil_mode_is_False/model.onnx +0 -0
  239. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_0_ceil_mode_is_True/model.onnx +0 -0
  240. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_False/model.onnx +0 -0
  241. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True/model.onnx +0 -0
  242. onnx/backend/test/data/node/test_averagepool_3d_dilations_small/model.onnx +0 -0
  243. onnx/backend/test/data/node/test_basic_conv_with_padding/model.onnx +0 -0
  244. onnx/backend/test/data/node/test_basic_conv_without_padding/model.onnx +0 -0
  245. onnx/backend/test/data/node/test_basic_deform_conv_with_padding/model.onnx +0 -0
  246. onnx/backend/test/data/node/test_basic_deform_conv_without_padding/model.onnx +0 -0
  247. onnx/backend/test/data/node/test_bernoulli/model.onnx +0 -0
  248. onnx/backend/test/data/node/test_bernoulli_double/model.onnx +0 -0
  249. onnx/backend/test/data/node/test_bernoulli_double_expanded/model.onnx +0 -0
  250. onnx/backend/test/data/node/test_bernoulli_expanded/model.onnx +0 -0
  251. onnx/backend/test/data/node/test_bernoulli_seed/model.onnx +0 -0
  252. onnx/backend/test/data/node/test_bernoulli_seed_expanded/model.onnx +0 -0
  253. onnx/backend/test/data/node/test_blackmanwindow/test_data_set_0/output_0.pb +0 -0
  254. onnx/backend/test/data/node/test_blackmanwindow_expanded/test_data_set_0/output_0.pb +0 -0
  255. onnx/backend/test/data/node/test_blackmanwindow_symmetric/test_data_set_0/output_0.pb +0 -0
  256. onnx/backend/test/data/node/test_blackmanwindow_symmetric_expanded/test_data_set_0/output_0.pb +0 -0
  257. onnx/backend/test/data/node/test_cast_FLOAT16_to_INT4/test_data_set_0/output_0.pb +1 -1
  258. onnx/backend/test/data/node/test_cast_FLOAT_to_INT4/test_data_set_0/output_0.pb +1 -1
  259. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT/test_data_set_0/input_0.pb +1 -1
  260. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT16/test_data_set_0/input_0.pb +1 -1
  261. onnx/backend/test/data/node/test_cast_INT4_to_INT8/test_data_set_0/input_0.pb +1 -1
  262. onnx/backend/test/data/node/test_conv_with_autopad_same/model.onnx +0 -0
  263. onnx/backend/test/data/node/test_conv_with_strides_and_asymmetric_padding/model.onnx +0 -0
  264. onnx/backend/test/data/node/test_conv_with_strides_no_padding/model.onnx +0 -0
  265. onnx/backend/test/data/node/test_conv_with_strides_padding/model.onnx +0 -0
  266. onnx/backend/test/data/node/test_convtranspose/model.onnx +0 -0
  267. onnx/backend/test/data/node/test_convtranspose_1d/model.onnx +0 -0
  268. onnx/backend/test/data/node/test_convtranspose_3d/model.onnx +0 -0
  269. onnx/backend/test/data/node/test_convtranspose_autopad_same/model.onnx +0 -0
  270. onnx/backend/test/data/node/test_convtranspose_dilations/model.onnx +0 -0
  271. onnx/backend/test/data/node/test_convtranspose_group_2/model.onnx +0 -0
  272. onnx/backend/test/data/node/test_convtranspose_group_2/test_data_set_0/input_0.pb +0 -0
  273. onnx/backend/test/data/node/test_convtranspose_group_2/test_data_set_0/input_1.pb +0 -0
  274. onnx/backend/test/data/node/test_convtranspose_group_2/test_data_set_0/output_0.pb +0 -0
  275. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/model.onnx +0 -0
  276. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/test_data_set_0/input_0.pb +0 -0
  277. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/test_data_set_0/input_1.pb +0 -0
  278. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/test_data_set_0/output_0.pb +0 -0
  279. onnx/backend/test/data/node/test_convtranspose_kernel_shape/model.onnx +0 -0
  280. onnx/backend/test/data/node/test_convtranspose_output_shape/model.onnx +0 -0
  281. onnx/backend/test/data/node/test_convtranspose_pad/model.onnx +0 -0
  282. onnx/backend/test/data/node/test_convtranspose_pads/model.onnx +0 -0
  283. onnx/backend/test/data/node/test_cos/model.onnx +0 -0
  284. onnx/backend/test/data/node/test_cos_example/model.onnx +0 -0
  285. onnx/backend/test/data/node/test_cosh/model.onnx +0 -0
  286. onnx/backend/test/data/node/test_cosh/test_data_set_0/output_0.pb +1 -1
  287. onnx/backend/test/data/node/test_cosh_example/model.onnx +0 -0
  288. onnx/backend/test/data/node/test_cosh_example/test_data_set_0/output_0.pb +0 -0
  289. onnx/backend/test/data/node/test_deform_conv_with_mask_bias/model.onnx +0 -0
  290. onnx/backend/test/data/node/test_deform_conv_with_multiple_offset_groups/model.onnx +0 -0
  291. onnx/backend/test/data/node/test_dequantizelinear_int4/test_data_set_0/input_0.pb +1 -1
  292. onnx/backend/test/data/node/test_det_2d/model.onnx +0 -0
  293. onnx/backend/test/data/node/test_det_nd/model.onnx +0 -0
  294. onnx/backend/test/data/node/test_dft/test_data_set_0/output_0.pb +0 -0
  295. onnx/backend/test/data/node/test_dft_axis/test_data_set_0/output_0.pb +0 -0
  296. onnx/backend/test/data/node/test_dft_axis_opset19/test_data_set_0/output_0.pb +0 -0
  297. onnx/backend/test/data/node/test_dft_inverse/test_data_set_0/output_0.pb +0 -0
  298. onnx/backend/test/data/node/test_dft_inverse_opset19/test_data_set_0/output_0.pb +0 -0
  299. onnx/backend/test/data/node/test_dft_opset19/test_data_set_0/output_0.pb +0 -0
  300. onnx/backend/test/data/node/test_dropout_default/model.onnx +0 -0
  301. onnx/backend/test/data/node/test_dropout_default_mask/model.onnx +0 -0
  302. onnx/backend/test/data/node/test_dropout_default_mask_ratio/model.onnx +0 -0
  303. onnx/backend/test/data/node/test_dropout_default_ratio/model.onnx +0 -0
  304. onnx/backend/test/data/node/test_elu/model.onnx +0 -0
  305. onnx/backend/test/data/node/test_elu_default/model.onnx +0 -0
  306. onnx/backend/test/data/node/test_elu_example/model.onnx +0 -0
  307. onnx/backend/test/data/node/test_eyelike_populate_off_main_diagonal/model.onnx +0 -0
  308. onnx/backend/test/data/node/test_eyelike_with_dtype/model.onnx +0 -0
  309. onnx/backend/test/data/node/test_eyelike_without_dtype/model.onnx +0 -0
  310. onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb +0 -0
  311. onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb +0 -0
  312. onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb +4 -3
  313. onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb +4 -3
  314. onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb +0 -0
  315. onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb +0 -0
  316. onnx/backend/test/data/node/test_globalaveragepool/model.onnx +0 -0
  317. onnx/backend/test/data/node/test_globalaveragepool_precomputed/model.onnx +0 -0
  318. onnx/backend/test/data/node/test_globalmaxpool/model.onnx +0 -0
  319. onnx/backend/test/data/node/test_globalmaxpool_precomputed/model.onnx +0 -0
  320. onnx/backend/test/data/node/test_gridsample/model.onnx +0 -0
  321. onnx/backend/test/data/node/test_gridsample_aligncorners_true/model.onnx +0 -0
  322. onnx/backend/test/data/node/test_gridsample_bicubic/model.onnx +0 -0
  323. onnx/backend/test/data/node/test_gridsample_bicubic_align_corners_0_additional_1/model.onnx +0 -0
  324. onnx/backend/test/data/node/test_gridsample_bicubic_align_corners_1_additional_1/model.onnx +0 -0
  325. onnx/backend/test/data/node/test_gridsample_bilinear/model.onnx +0 -0
  326. onnx/backend/test/data/node/test_gridsample_bilinear_align_corners_0_additional_1/model.onnx +0 -0
  327. onnx/backend/test/data/node/test_gridsample_bilinear_align_corners_1_additional_1/model.onnx +0 -0
  328. onnx/backend/test/data/node/test_gridsample_border_padding/model.onnx +0 -0
  329. onnx/backend/test/data/node/test_gridsample_nearest/model.onnx +0 -0
  330. onnx/backend/test/data/node/test_gridsample_nearest_align_corners_0_additional_1/model.onnx +0 -0
  331. onnx/backend/test/data/node/test_gridsample_nearest_align_corners_1_additional_1/model.onnx +0 -0
  332. onnx/backend/test/data/node/test_gridsample_reflection_padding/model.onnx +0 -0
  333. onnx/backend/test/data/node/test_gridsample_volumetric_bilinear_align_corners_0/model.onnx +0 -0
  334. onnx/backend/test/data/node/test_gridsample_volumetric_bilinear_align_corners_1/model.onnx +0 -0
  335. onnx/backend/test/data/node/test_gridsample_volumetric_nearest_align_corners_0/model.onnx +0 -0
  336. onnx/backend/test/data/node/test_gridsample_volumetric_nearest_align_corners_1/model.onnx +0 -0
  337. onnx/backend/test/data/node/test_gridsample_zeros_padding/model.onnx +0 -0
  338. onnx/backend/test/data/node/test_gru_batchwise/model.onnx +0 -0
  339. onnx/backend/test/data/node/test_gru_defaults/model.onnx +0 -0
  340. onnx/backend/test/data/node/test_gru_seq_length/model.onnx +0 -0
  341. onnx/backend/test/data/node/test_gru_with_initial_bias/model.onnx +0 -0
  342. onnx/backend/test/data/node/test_hammingwindow/test_data_set_0/output_0.pb +0 -0
  343. onnx/backend/test/data/node/test_hammingwindow_expanded/test_data_set_0/output_0.pb +0 -0
  344. onnx/backend/test/data/node/test_hammingwindow_symmetric/test_data_set_0/output_0.pb +1 -1
  345. onnx/backend/test/data/node/test_hammingwindow_symmetric_expanded/test_data_set_0/output_0.pb +1 -1
  346. onnx/backend/test/data/node/test_hannwindow/test_data_set_0/output_0.pb +0 -0
  347. onnx/backend/test/data/node/test_hannwindow_expanded/test_data_set_0/output_0.pb +0 -0
  348. onnx/backend/test/data/node/test_hannwindow_symmetric/test_data_set_0/output_0.pb +0 -0
  349. onnx/backend/test/data/node/test_hannwindow_symmetric_expanded/test_data_set_0/output_0.pb +0 -0
  350. onnx/backend/test/data/node/test_hardsigmoid/model.onnx +0 -0
  351. onnx/backend/test/data/node/test_hardsigmoid_default/model.onnx +0 -0
  352. onnx/backend/test/data/node/test_hardsigmoid_example/model.onnx +0 -0
  353. onnx/backend/test/data/node/test_hardswish/model.onnx +0 -0
  354. onnx/backend/test/data/node/test_hardswish_expanded/model.onnx +0 -0
  355. onnx/backend/test/data/node/test_image_decoder_decode_jpeg2k_rgb/test_data_set_0/input_0.pb +0 -0
  356. onnx/backend/test/data/node/test_instancenorm_epsilon/model.onnx +0 -0
  357. onnx/backend/test/data/node/test_instancenorm_example/model.onnx +0 -0
  358. onnx/backend/test/data/node/test_lppool_1d_default/model.onnx +0 -0
  359. onnx/backend/test/data/node/test_lppool_1d_default/test_data_set_0/output_0.pb +2 -2
  360. onnx/backend/test/data/node/test_lppool_2d_default/model.onnx +0 -0
  361. onnx/backend/test/data/node/test_lppool_2d_default/test_data_set_0/output_0.pb +0 -0
  362. onnx/backend/test/data/node/test_lppool_2d_dilations/model.onnx +0 -0
  363. onnx/backend/test/data/node/test_lppool_2d_pads/model.onnx +0 -0
  364. onnx/backend/test/data/node/test_lppool_2d_pads/test_data_set_0/output_0.pb +0 -0
  365. onnx/backend/test/data/node/test_lppool_2d_same_lower/model.onnx +0 -0
  366. onnx/backend/test/data/node/test_lppool_2d_same_lower/test_data_set_0/output_0.pb +0 -0
  367. onnx/backend/test/data/node/test_lppool_2d_same_upper/model.onnx +0 -0
  368. onnx/backend/test/data/node/test_lppool_2d_same_upper/test_data_set_0/output_0.pb +0 -0
  369. onnx/backend/test/data/node/test_lppool_2d_strides/model.onnx +0 -0
  370. onnx/backend/test/data/node/test_lppool_2d_strides/test_data_set_0/output_0.pb +0 -0
  371. onnx/backend/test/data/node/test_lppool_3d_default/model.onnx +0 -0
  372. onnx/backend/test/data/node/test_lppool_3d_default/test_data_set_0/output_0.pb +0 -0
  373. onnx/backend/test/data/node/test_lstm_batchwise/model.onnx +0 -0
  374. onnx/backend/test/data/node/test_lstm_defaults/model.onnx +0 -0
  375. onnx/backend/test/data/node/test_lstm_with_initial_bias/model.onnx +0 -0
  376. onnx/backend/test/data/node/test_lstm_with_peepholes/model.onnx +0 -0
  377. onnx/backend/test/data/node/test_maxpool_1d_default/model.onnx +0 -0
  378. onnx/backend/test/data/node/test_maxpool_2d_ceil/model.onnx +0 -0
  379. onnx/backend/test/data/node/test_maxpool_2d_ceil_output_size_reduce_by_one/model.onnx +0 -0
  380. onnx/backend/test/data/node/test_maxpool_2d_default/model.onnx +0 -0
  381. onnx/backend/test/data/node/test_maxpool_2d_dilations/model.onnx +0 -0
  382. onnx/backend/test/data/node/test_maxpool_2d_pads/model.onnx +0 -0
  383. onnx/backend/test/data/node/test_maxpool_2d_precomputed_pads/model.onnx +0 -0
  384. onnx/backend/test/data/node/test_maxpool_2d_precomputed_same_upper/model.onnx +0 -0
  385. onnx/backend/test/data/node/test_maxpool_2d_precomputed_strides/model.onnx +0 -0
  386. onnx/backend/test/data/node/test_maxpool_2d_same_lower/model.onnx +0 -0
  387. onnx/backend/test/data/node/test_maxpool_2d_same_upper/model.onnx +0 -0
  388. onnx/backend/test/data/node/test_maxpool_2d_strides/model.onnx +0 -0
  389. onnx/backend/test/data/node/test_maxpool_2d_uint8/model.onnx +0 -0
  390. onnx/backend/test/data/node/test_maxpool_3d_default/model.onnx +0 -0
  391. onnx/backend/test/data/node/test_maxpool_3d_dilations/model.onnx +0 -0
  392. onnx/backend/test/data/node/test_maxpool_3d_dilations_use_ref_impl/model.onnx +0 -0
  393. onnx/backend/test/data/node/test_maxpool_3d_dilations_use_ref_impl_large/model.onnx +0 -0
  394. onnx/backend/test/data/node/test_maxpool_with_argmax_2d_precomputed_pads/model.onnx +0 -0
  395. onnx/backend/test/data/node/test_maxpool_with_argmax_2d_precomputed_strides/model.onnx +0 -0
  396. onnx/backend/test/data/node/test_maxunpool_export_with_output_shape/model.onnx +0 -0
  397. onnx/backend/test/data/node/test_maxunpool_export_without_output_shape/model.onnx +0 -0
  398. onnx/backend/test/data/node/test_mish/model.onnx +0 -0
  399. onnx/backend/test/data/node/test_mish/test_data_set_0/output_0.pb +0 -0
  400. onnx/backend/test/data/node/test_mish_expanded/model.onnx +0 -0
  401. onnx/backend/test/data/node/test_mish_expanded/test_data_set_0/output_0.pb +0 -0
  402. onnx/backend/test/data/node/test_nllloss_NC/model.onnx +0 -0
  403. onnx/backend/test/data/node/test_nllloss_NC_expanded/model.onnx +0 -0
  404. onnx/backend/test/data/node/test_nllloss_NCd1/model.onnx +0 -0
  405. onnx/backend/test/data/node/test_nllloss_NCd1_expanded/model.onnx +0 -0
  406. onnx/backend/test/data/node/test_nllloss_NCd1_ii/model.onnx +0 -0
  407. onnx/backend/test/data/node/test_nllloss_NCd1_ii_expanded/model.onnx +0 -0
  408. onnx/backend/test/data/node/test_nllloss_NCd1_mean_weight_negative_ii/model.onnx +0 -0
  409. onnx/backend/test/data/node/test_nllloss_NCd1_mean_weight_negative_ii_expanded/model.onnx +0 -0
  410. onnx/backend/test/data/node/test_nllloss_NCd1_weight/model.onnx +0 -0
  411. onnx/backend/test/data/node/test_nllloss_NCd1_weight_expanded/model.onnx +0 -0
  412. onnx/backend/test/data/node/test_nllloss_NCd1_weight_ii/model.onnx +0 -0
  413. onnx/backend/test/data/node/test_nllloss_NCd1_weight_ii_expanded/model.onnx +0 -0
  414. onnx/backend/test/data/node/test_nllloss_NCd1d2/model.onnx +0 -0
  415. onnx/backend/test/data/node/test_nllloss_NCd1d2_expanded/model.onnx +0 -0
  416. onnx/backend/test/data/node/test_nllloss_NCd1d2_no_weight_reduction_mean_ii/model.onnx +0 -0
  417. onnx/backend/test/data/node/test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded/model.onnx +0 -0
  418. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_mean/model.onnx +0 -0
  419. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_mean_expanded/model.onnx +0 -0
  420. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_sum/model.onnx +0 -0
  421. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_sum_expanded/model.onnx +0 -0
  422. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight/model.onnx +0 -0
  423. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_expanded/model.onnx +0 -0
  424. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_mean/model.onnx +0 -0
  425. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_mean_expanded/model.onnx +0 -0
  426. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum/model.onnx +0 -0
  427. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum_expanded/model.onnx +0 -0
  428. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum_ii/model.onnx +0 -0
  429. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded/model.onnx +0 -0
  430. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_none_no_weight_negative_ii/model.onnx +0 -0
  431. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded/model.onnx +0 -0
  432. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_sum_weight_high_ii/model.onnx +0 -0
  433. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded/model.onnx +0 -0
  434. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_mean_weight/model.onnx +0 -0
  435. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_mean_weight_expanded/model.onnx +0 -0
  436. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_none_no_weight/model.onnx +0 -0
  437. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded/model.onnx +0 -0
  438. onnx/backend/test/data/node/test_quantizelinear_int4/test_data_set_0/output_0.pb +1 -1
  439. onnx/backend/test/data/node/test_reduce_log_sum_exp_do_not_keepdims_random/test_data_set_0/output_0.pb +1 -1
  440. onnx/backend/test/data/node/test_reduce_log_sum_exp_do_not_keepdims_random_expanded/test_data_set_0/output_0.pb +1 -1
  441. onnx/backend/test/data/node/test_reduce_log_sum_exp_keepdims_random/test_data_set_0/output_0.pb +1 -1
  442. onnx/backend/test/data/node/test_reduce_log_sum_exp_keepdims_random_expanded/test_data_set_0/output_0.pb +1 -1
  443. onnx/backend/test/data/node/test_reduce_log_sum_exp_negative_axes_keepdims_random/test_data_set_0/output_0.pb +1 -1
  444. onnx/backend/test/data/node/test_reduce_log_sum_exp_negative_axes_keepdims_random_expanded/test_data_set_0/output_0.pb +1 -1
  445. onnx/backend/test/data/node/test_reduce_max_empty_set/model.onnx +0 -0
  446. onnx/backend/test/data/node/test_reduce_max_empty_set/test_data_set_0/input_0.pb +0 -0
  447. onnx/backend/test/data/node/test_reduce_max_empty_set/test_data_set_0/input_1.pb +0 -0
  448. onnx/backend/test/data/node/test_reduce_max_empty_set/test_data_set_0/output_0.pb +0 -0
  449. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/model.onnx +0 -0
  450. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/test_data_set_0/input_0.pb +1 -0
  451. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/test_data_set_0/input_1.pb +0 -0
  452. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/test_data_set_0/output_0.pb +1 -0
  453. onnx/backend/test/data/node/test_reduce_sum_negative_axes_keepdims_random/model.onnx +0 -0
  454. onnx/backend/test/data/node/test_reduce_sum_negative_axes_keepdims_random/test_data_set_0/input_1.pb +0 -0
  455. onnx/backend/test/data/node/test_reduce_sum_negative_axes_keepdims_random/test_data_set_0/output_0.pb +1 -1
  456. onnx/backend/test/data/node/test_resize_tf_crop_and_resize/model.onnx +0 -0
  457. onnx/backend/test/data/node/test_resize_tf_crop_and_resize/test_data_set_0/input_1.pb +0 -0
  458. onnx/backend/test/data/node/test_resize_tf_crop_and_resize/test_data_set_0/output_0.pb +0 -0
  459. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/model.onnx +0 -0
  460. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/input_0.pb +0 -0
  461. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/input_1.pb +0 -0
  462. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/input_2.pb +0 -0
  463. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/output_0.pb +0 -0
  464. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_larger/model.onnx +0 -0
  465. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_larger/test_data_set_0/output_0.pb +0 -0
  466. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/model.onnx +0 -0
  467. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/test_data_set_0/input_0.pb +0 -0
  468. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/test_data_set_0/input_1.pb +0 -0
  469. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/test_data_set_0/output_0.pb +0 -0
  470. onnx/backend/test/data/node/test_rnn_seq_length/model.onnx +0 -0
  471. onnx/backend/test/data/node/test_roialign_aligned_false/model.onnx +0 -0
  472. onnx/backend/test/data/node/test_roialign_aligned_true/model.onnx +0 -0
  473. onnx/backend/test/data/node/test_roialign_mode_max/model.onnx +0 -0
  474. onnx/backend/test/data/node/test_round/model.onnx +0 -0
  475. onnx/backend/test/data/node/test_selu/model.onnx +0 -0
  476. onnx/backend/test/data/node/test_selu_default/model.onnx +0 -0
  477. onnx/backend/test/data/node/test_selu_example/model.onnx +0 -0
  478. onnx/backend/test/data/node/test_simple_rnn_batchwise/model.onnx +0 -0
  479. onnx/backend/test/data/node/test_simple_rnn_defaults/model.onnx +0 -0
  480. onnx/backend/test/data/node/test_simple_rnn_with_initial_bias/model.onnx +0 -0
  481. onnx/backend/test/data/node/test_sin/model.onnx +0 -0
  482. onnx/backend/test/data/node/test_sin_example/model.onnx +0 -0
  483. onnx/backend/test/data/node/test_sinh/model.onnx +0 -0
  484. onnx/backend/test/data/node/test_sinh/test_data_set_0/output_0.pb +1 -1
  485. onnx/backend/test/data/node/test_sinh_example/model.onnx +0 -0
  486. onnx/backend/test/data/node/test_softplus/model.onnx +0 -0
  487. onnx/backend/test/data/node/test_softplus_example/model.onnx +0 -0
  488. onnx/backend/test/data/node/test_softsign/model.onnx +0 -0
  489. onnx/backend/test/data/node/test_softsign_example/model.onnx +0 -0
  490. onnx/backend/test/data/node/test_stft_with_window/test_data_set_0/input_2.pb +0 -0
  491. onnx/backend/test/data/node/test_stft_with_window/test_data_set_0/output_0.pb +0 -0
  492. onnx/backend/test/data/node/test_tan/model.onnx +0 -0
  493. onnx/backend/test/data/node/test_tan/test_data_set_0/output_0.pb +1 -1
  494. onnx/backend/test/data/node/test_tan_example/model.onnx +0 -0
  495. onnx/backend/test/data/node/test_thresholdedrelu/model.onnx +0 -0
  496. onnx/backend/test/data/node/test_thresholdedrelu_default/model.onnx +0 -0
  497. onnx/backend/test/data/node/test_thresholdedrelu_example/model.onnx +0 -0
  498. onnx/backend/test/data/node/test_training_dropout/model.onnx +0 -0
  499. onnx/backend/test/data/node/test_training_dropout_default/model.onnx +0 -0
  500. onnx/backend/test/data/node/test_training_dropout_default_mask/model.onnx +0 -0
  501. onnx/backend/test/data/node/test_training_dropout_mask/model.onnx +0 -0
  502. onnx/backend/test/data/node/test_training_dropout_zero_ratio/model.onnx +0 -0
  503. onnx/backend/test/data/node/test_training_dropout_zero_ratio_mask/model.onnx +0 -0
  504. onnx/backend/test/loader/__init__.py +11 -6
  505. onnx/backend/test/report/__init__.py +4 -3
  506. onnx/backend/test/report/base.py +1 -0
  507. onnx/backend/test/report/coverage.py +21 -20
  508. onnx/backend/test/runner/__init__.py +13 -11
  509. onnx/backend/test/runner/item.py +3 -2
  510. onnx/backend/test/stat_coverage.py +6 -5
  511. onnx/bin/checker.py +1 -0
  512. onnx/checker.cc +6 -1
  513. onnx/common/version.h +1 -1
  514. onnx/compose.py +66 -50
  515. onnx/cpp2py_export.cc +4 -0
  516. onnx/defs/__init__.py +2 -2
  517. onnx/defs/data_type_utils.cc +0 -1
  518. onnx/defs/gen_doc.py +9 -8
  519. onnx/defs/gen_shape_inference_information.py +1 -0
  520. onnx/defs/generator/defs.cc +32 -84
  521. onnx/defs/generator/old.cc +389 -0
  522. onnx/defs/math/defs.cc +308 -313
  523. onnx/defs/math/old.cc +996 -9
  524. onnx/defs/math/utils.cc +12 -1
  525. onnx/defs/math/utils.h +2 -0
  526. onnx/defs/nn/defs.cc +57 -75
  527. onnx/defs/nn/old.cc +1536 -2
  528. onnx/defs/object_detection/defs.cc +4 -7
  529. onnx/defs/object_detection/old.cc +117 -0
  530. onnx/defs/operator_sets.h +108 -1
  531. onnx/defs/parser.cc +10 -1
  532. onnx/defs/quantization/defs.cc +3 -2
  533. onnx/defs/quantization/old.cc +4 -1
  534. onnx/defs/rnn/defs.cc +10 -13
  535. onnx/defs/rnn/old.cc +517 -2
  536. onnx/defs/schema.cc +53 -59
  537. onnx/defs/schema.h +58 -2
  538. onnx/defs/shape_inference.h +67 -18
  539. onnx/defs/tensor/defs.cc +22 -20
  540. onnx/defs/tensor/old.cc +114 -3
  541. onnx/external_data_helper.py +27 -14
  542. onnx/gen_proto.py +3 -2
  543. onnx/helper.py +86 -61
  544. onnx/hub.py +39 -35
  545. onnx/inliner/inliner.cc +0 -1
  546. onnx/mapping.py +3 -2
  547. onnx/numpy_helper.py +159 -23
  548. onnx/onnx-ml.proto +1 -1
  549. onnx/onnx.in.proto +1 -1
  550. onnx/onnx.proto +1 -1
  551. onnx/onnx_cpp2py_export/defs.pyi +0 -2
  552. onnx/onnx_cpp2py_export/inliner.pyi +0 -4
  553. onnx/onnx_cpp2py_export/parser.pyi +0 -4
  554. onnx/onnx_cpp2py_export.cp38-win32.pyd +0 -0
  555. onnx/parser.py +1 -0
  556. onnx/printer.py +2 -3
  557. onnx/reference/__init__.py +1 -0
  558. onnx/reference/custom_element_types.py +73 -8
  559. onnx/reference/op_run.py +13 -58
  560. onnx/reference/ops/__init__.py +1 -0
  561. onnx/reference/ops/_helpers.py +6 -4
  562. onnx/reference/ops/_op.py +16 -5
  563. onnx/reference/ops/_op_common_indices.py +1 -1
  564. onnx/reference/ops/_op_common_pool.py +38 -29
  565. onnx/reference/ops/_op_common_random.py +1 -1
  566. onnx/reference/ops/_op_common_window.py +2 -2
  567. onnx/reference/ops/_op_list.py +9 -6
  568. onnx/reference/ops/aionnx_preview_training/__init__.py +1 -0
  569. onnx/reference/ops/aionnx_preview_training/_op_list.py +5 -7
  570. onnx/reference/ops/aionnx_preview_training/_op_run_training.py +1 -1
  571. onnx/reference/ops/aionnx_preview_training/op_adagrad.py +14 -5
  572. onnx/reference/ops/aionnx_preview_training/op_adam.py +2 -2
  573. onnx/reference/ops/aionnx_preview_training/op_momentum.py +14 -2
  574. onnx/reference/ops/aionnxml/__init__.py +1 -0
  575. onnx/reference/ops/aionnxml/_common_classifier.py +1 -0
  576. onnx/reference/ops/aionnxml/_op_list.py +5 -6
  577. onnx/reference/ops/aionnxml/_op_run_aionnxml.py +1 -1
  578. onnx/reference/ops/aionnxml/op_array_feature_extractor.py +1 -1
  579. onnx/reference/ops/aionnxml/op_binarizer.py +1 -1
  580. onnx/reference/ops/aionnxml/op_dict_vectorizer.py +2 -2
  581. onnx/reference/ops/aionnxml/op_feature_vectorizer.py +1 -1
  582. onnx/reference/ops/aionnxml/op_imputer.py +3 -3
  583. onnx/reference/ops/aionnxml/op_label_encoder.py +1 -1
  584. onnx/reference/ops/aionnxml/op_linear_classifier.py +2 -2
  585. onnx/reference/ops/aionnxml/op_linear_regressor.py +1 -1
  586. onnx/reference/ops/aionnxml/op_normalizer.py +1 -1
  587. onnx/reference/ops/aionnxml/op_one_hot_encoder.py +1 -1
  588. onnx/reference/ops/aionnxml/op_scaler.py +1 -1
  589. onnx/reference/ops/aionnxml/op_svm_classifier.py +10 -7
  590. onnx/reference/ops/aionnxml/op_svm_helper.py +2 -2
  591. onnx/reference/ops/aionnxml/op_svm_regressor.py +1 -1
  592. onnx/reference/ops/aionnxml/op_tree_ensemble.py +3 -3
  593. onnx/reference/ops/aionnxml/op_tree_ensemble_classifier.py +1 -1
  594. onnx/reference/ops/aionnxml/op_tree_ensemble_helper.py +2 -2
  595. onnx/reference/ops/aionnxml/op_tree_ensemble_regressor.py +5 -3
  596. onnx/reference/ops/experimental/__init__.py +1 -0
  597. onnx/reference/ops/experimental/_op_list.py +6 -12
  598. onnx/reference/ops/experimental/_op_run_experimental.py +1 -1
  599. onnx/reference/ops/experimental/op_im2col.py +1 -1
  600. onnx/reference/ops/op_abs.py +1 -1
  601. onnx/reference/ops/op_acos.py +1 -1
  602. onnx/reference/ops/op_acosh.py +1 -1
  603. onnx/reference/ops/op_add.py +1 -1
  604. onnx/reference/ops/op_affine_grid.py +1 -1
  605. onnx/reference/ops/op_and.py +1 -1
  606. onnx/reference/ops/op_argmax.py +1 -1
  607. onnx/reference/ops/op_argmin.py +1 -1
  608. onnx/reference/ops/op_asin.py +1 -1
  609. onnx/reference/ops/op_asinh.py +1 -1
  610. onnx/reference/ops/op_atan.py +1 -1
  611. onnx/reference/ops/op_atanh.py +1 -1
  612. onnx/reference/ops/op_attribute_has_value.py +15 -15
  613. onnx/reference/ops/op_average_pool.py +1 -1
  614. onnx/reference/ops/op_batch_normalization.py +13 -2
  615. onnx/reference/ops/op_bernoulli.py +1 -1
  616. onnx/reference/ops/op_bitshift.py +1 -1
  617. onnx/reference/ops/op_bitwise_and.py +1 -1
  618. onnx/reference/ops/op_bitwise_not.py +1 -1
  619. onnx/reference/ops/op_bitwise_or.py +1 -1
  620. onnx/reference/ops/op_bitwise_xor.py +1 -1
  621. onnx/reference/ops/op_blackman_window.py +1 -1
  622. onnx/reference/ops/op_cast.py +11 -10
  623. onnx/reference/ops/op_cast_like.py +1 -1
  624. onnx/reference/ops/op_ceil.py +1 -1
  625. onnx/reference/ops/op_celu.py +1 -1
  626. onnx/reference/ops/op_center_crop_pad.py +1 -1
  627. onnx/reference/ops/op_clip.py +1 -1
  628. onnx/reference/ops/op_col2im.py +10 -4
  629. onnx/reference/ops/op_compress.py +1 -1
  630. onnx/reference/ops/op_concat.py +1 -1
  631. onnx/reference/ops/op_concat_from_sequence.py +3 -3
  632. onnx/reference/ops/op_constant.py +2 -2
  633. onnx/reference/ops/op_constant_of_shape.py +1 -1
  634. onnx/reference/ops/op_conv.py +22 -17
  635. onnx/reference/ops/op_conv_integer.py +1 -1
  636. onnx/reference/ops/op_conv_transpose.py +37 -6
  637. onnx/reference/ops/op_cos.py +1 -1
  638. onnx/reference/ops/op_cosh.py +1 -1
  639. onnx/reference/ops/op_cum_sum.py +1 -1
  640. onnx/reference/ops/op_deform_conv.py +1 -1
  641. onnx/reference/ops/op_depth_to_space.py +1 -1
  642. onnx/reference/ops/op_dequantize_linear.py +7 -9
  643. onnx/reference/ops/op_det.py +1 -1
  644. onnx/reference/ops/op_dft.py +16 -2
  645. onnx/reference/ops/op_div.py +1 -1
  646. onnx/reference/ops/op_dropout.py +9 -8
  647. onnx/reference/ops/op_dynamic_quantize_linear.py +1 -1
  648. onnx/reference/ops/op_einsum.py +1 -1
  649. onnx/reference/ops/op_elu.py +1 -1
  650. onnx/reference/ops/op_equal.py +1 -1
  651. onnx/reference/ops/op_erf.py +1 -1
  652. onnx/reference/ops/op_exp.py +1 -1
  653. onnx/reference/ops/op_expand.py +1 -1
  654. onnx/reference/ops/op_eyelike.py +2 -2
  655. onnx/reference/ops/op_flatten.py +1 -1
  656. onnx/reference/ops/op_floor.py +1 -1
  657. onnx/reference/ops/op_gather.py +1 -1
  658. onnx/reference/ops/op_gather_elements.py +3 -3
  659. onnx/reference/ops/op_gathernd.py +2 -4
  660. onnx/reference/ops/op_gemm.py +12 -2
  661. onnx/reference/ops/op_global_average_pool.py +1 -1
  662. onnx/reference/ops/op_global_max_pool.py +1 -1
  663. onnx/reference/ops/op_greater.py +1 -1
  664. onnx/reference/ops/op_greater_or_equal.py +1 -1
  665. onnx/reference/ops/op_grid_sample.py +2 -3
  666. onnx/reference/ops/op_gru.py +7 -7
  667. onnx/reference/ops/op_hamming_window.py +1 -1
  668. onnx/reference/ops/op_hann_window.py +1 -1
  669. onnx/reference/ops/op_hard_sigmoid.py +1 -1
  670. onnx/reference/ops/op_hardmax.py +5 -2
  671. onnx/reference/ops/op_identity.py +3 -3
  672. onnx/reference/ops/op_if.py +2 -2
  673. onnx/reference/ops/op_instance_normalization.py +1 -1
  674. onnx/reference/ops/op_isinf.py +1 -1
  675. onnx/reference/ops/op_isnan.py +1 -1
  676. onnx/reference/ops/op_layer_normalization.py +2 -4
  677. onnx/reference/ops/op_leaky_relu.py +1 -1
  678. onnx/reference/ops/op_less.py +1 -1
  679. onnx/reference/ops/op_less_or_equal.py +1 -1
  680. onnx/reference/ops/op_log.py +1 -1
  681. onnx/reference/ops/op_log_softmax.py +1 -1
  682. onnx/reference/ops/op_loop.py +4 -2
  683. onnx/reference/ops/op_lp_normalization.py +1 -1
  684. onnx/reference/ops/op_lp_pool.py +4 -2
  685. onnx/reference/ops/op_lrn.py +1 -1
  686. onnx/reference/ops/op_lstm.py +9 -11
  687. onnx/reference/ops/op_matmul.py +1 -1
  688. onnx/reference/ops/op_matmul_integer.py +1 -1
  689. onnx/reference/ops/op_max.py +1 -1
  690. onnx/reference/ops/op_max_pool.py +8 -8
  691. onnx/reference/ops/op_max_unpool.py +5 -3
  692. onnx/reference/ops/op_mean.py +1 -1
  693. onnx/reference/ops/op_mel_weight_matrix.py +1 -1
  694. onnx/reference/ops/op_min.py +1 -1
  695. onnx/reference/ops/op_mod.py +1 -1
  696. onnx/reference/ops/op_mul.py +1 -1
  697. onnx/reference/ops/op_neg.py +1 -1
  698. onnx/reference/ops/op_negative_log_likelihood_loss.py +4 -2
  699. onnx/reference/ops/op_non_max_suppression.py +10 -11
  700. onnx/reference/ops/op_non_zero.py +1 -1
  701. onnx/reference/ops/op_not.py +1 -1
  702. onnx/reference/ops/op_one_hot.py +1 -1
  703. onnx/reference/ops/op_optional.py +1 -1
  704. onnx/reference/ops/op_optional_get_element.py +1 -1
  705. onnx/reference/ops/op_optional_has_element.py +1 -1
  706. onnx/reference/ops/op_or.py +1 -1
  707. onnx/reference/ops/op_pad.py +1 -1
  708. onnx/reference/ops/op_pool_common.py +7 -6
  709. onnx/reference/ops/op_pow.py +1 -1
  710. onnx/reference/ops/op_prelu.py +3 -3
  711. onnx/reference/ops/op_qlinear_conv.py +1 -1
  712. onnx/reference/ops/op_qlinear_matmul.py +1 -1
  713. onnx/reference/ops/op_quantize_linear.py +15 -9
  714. onnx/reference/ops/op_random_normal.py +1 -1
  715. onnx/reference/ops/op_random_normal_like.py +1 -1
  716. onnx/reference/ops/op_random_uniform.py +1 -1
  717. onnx/reference/ops/op_random_uniform_like.py +1 -1
  718. onnx/reference/ops/op_range.py +1 -1
  719. onnx/reference/ops/op_reciprocal.py +1 -1
  720. onnx/reference/ops/op_reduce_l1.py +1 -1
  721. onnx/reference/ops/op_reduce_l2.py +1 -1
  722. onnx/reference/ops/op_reduce_log_sum.py +1 -1
  723. onnx/reference/ops/op_reduce_log_sum_exp.py +1 -1
  724. onnx/reference/ops/op_reduce_max.py +1 -1
  725. onnx/reference/ops/op_reduce_mean.py +2 -2
  726. onnx/reference/ops/op_reduce_min.py +1 -1
  727. onnx/reference/ops/op_reduce_prod.py +1 -1
  728. onnx/reference/ops/op_reduce_sum.py +2 -2
  729. onnx/reference/ops/op_reduce_sum_square.py +1 -1
  730. onnx/reference/ops/op_regex_full_match.py +1 -1
  731. onnx/reference/ops/op_relu.py +1 -1
  732. onnx/reference/ops/op_reshape.py +1 -1
  733. onnx/reference/ops/op_reverse_sequence.py +1 -1
  734. onnx/reference/ops/op_rnn.py +10 -8
  735. onnx/reference/ops/op_roi_align.py +5 -5
  736. onnx/reference/ops/op_round.py +1 -1
  737. onnx/reference/ops/op_scan.py +8 -8
  738. onnx/reference/ops/op_scatter_elements.py +19 -50
  739. onnx/reference/ops/op_scatternd.py +1 -1
  740. onnx/reference/ops/op_selu.py +1 -1
  741. onnx/reference/ops/op_sequence_at.py +1 -1
  742. onnx/reference/ops/op_sequence_construct.py +1 -1
  743. onnx/reference/ops/op_sequence_empty.py +2 -2
  744. onnx/reference/ops/op_sequence_erase.py +1 -1
  745. onnx/reference/ops/op_sequence_insert.py +6 -6
  746. onnx/reference/ops/op_sequence_length.py +1 -1
  747. onnx/reference/ops/op_sequence_map.py +1 -1
  748. onnx/reference/ops/op_shape.py +2 -6
  749. onnx/reference/ops/op_shrink.py +1 -1
  750. onnx/reference/ops/op_sigmoid.py +1 -1
  751. onnx/reference/ops/op_sign.py +1 -1
  752. onnx/reference/ops/op_sin.py +1 -1
  753. onnx/reference/ops/op_sinh.py +1 -1
  754. onnx/reference/ops/op_size.py +1 -1
  755. onnx/reference/ops/op_slice.py +3 -5
  756. onnx/reference/ops/op_softmax.py +1 -1
  757. onnx/reference/ops/op_softmax_cross_entropy_loss.py +1 -1
  758. onnx/reference/ops/op_softplus.py +1 -1
  759. onnx/reference/ops/op_softsign.py +1 -1
  760. onnx/reference/ops/op_space_to_depth.py +1 -1
  761. onnx/reference/ops/op_split.py +1 -1
  762. onnx/reference/ops/op_split_to_sequence.py +5 -7
  763. onnx/reference/ops/op_sqrt.py +1 -1
  764. onnx/reference/ops/op_squeeze.py +1 -1
  765. onnx/reference/ops/op_stft.py +3 -2
  766. onnx/reference/ops/op_string_concat.py +1 -1
  767. onnx/reference/ops/op_string_normalizer.py +8 -8
  768. onnx/reference/ops/op_string_split.py +2 -4
  769. onnx/reference/ops/op_sub.py +1 -1
  770. onnx/reference/ops/op_sum.py +1 -1
  771. onnx/reference/ops/op_tan.py +1 -1
  772. onnx/reference/ops/op_tanh.py +1 -1
  773. onnx/reference/ops/op_tfidf_vectorizer.py +11 -12
  774. onnx/reference/ops/op_thresholded_relu.py +1 -1
  775. onnx/reference/ops/op_tile.py +1 -1
  776. onnx/reference/ops/op_topk.py +7 -2
  777. onnx/reference/ops/op_transpose.py +1 -1
  778. onnx/reference/ops/op_trilu.py +1 -1
  779. onnx/reference/ops/op_unique.py +3 -1
  780. onnx/reference/ops/op_unsqueeze.py +2 -2
  781. onnx/reference/ops/op_upsample.py +1 -1
  782. onnx/reference/ops/op_where.py +1 -1
  783. onnx/reference/ops/op_xor.py +1 -1
  784. onnx/reference/ops_optimized/__init__.py +1 -0
  785. onnx/reference/ops_optimized/op_conv_optimized.py +1 -1
  786. onnx/reference/reference_evaluator.py +27 -13
  787. onnx/serialization.py +1 -1
  788. onnx/shape_inference/implementation.cc +15 -1
  789. onnx/shape_inference/implementation.h +15 -1
  790. onnx/shape_inference.py +1 -1
  791. onnx/subbyte.py +6 -6
  792. onnx/test/basic_test.py +1 -0
  793. onnx/test/checker_test.py +37 -2
  794. onnx/test/compose_test.py +12 -11
  795. onnx/test/cpp/schema_registration_test.cc +3 -3
  796. onnx/test/cpp/shape_inference_test.cc +38 -2
  797. onnx/test/elu_test.py +2 -0
  798. onnx/test/function_inference_test.py +2 -0
  799. onnx/test/function_test.py +1 -0
  800. onnx/test/helper_test.py +77 -16
  801. onnx/test/hub_test.py +1 -1
  802. onnx/test/inference_function_test.py +25 -8
  803. onnx/test/inliner_test.py +2 -0
  804. onnx/test/model_container_refeval_test.py +2 -1
  805. onnx/test/model_container_test.py +1 -0
  806. onnx/test/model_inference_test.py +2 -0
  807. onnx/test/numpy_helper_test.py +56 -1
  808. onnx/test/parser_test.py +48 -2
  809. onnx/test/printer_test.py +2 -0
  810. onnx/test/reference_evaluator_ml_test.py +2 -3
  811. onnx/test/reference_evaluator_model_test.py +2 -0
  812. onnx/test/reference_evaluator_test.py +173 -19
  813. onnx/test/relu_test.py +2 -0
  814. onnx/test/schema_test.py +4 -2
  815. onnx/test/serialization_test.py +2 -0
  816. onnx/test/shape_inference_test.py +349 -19
  817. onnx/test/symbolic_shape_test.py +3 -3
  818. onnx/test/test_backend_onnxruntime.py +272 -1
  819. onnx/test/test_backend_reference.py +24 -3
  820. onnx/test/test_backend_test.py +6 -5
  821. onnx/test/test_external_data.py +91 -2
  822. onnx/test/test_with_ort.py +1 -0
  823. onnx/test/tools_test.py +15 -14
  824. onnx/test/training_tool_test.py +1 -0
  825. onnx/test/utils_test.py +1 -0
  826. onnx/test/version_converter/automatic_downgrade_test.py +2 -0
  827. onnx/test/version_converter/automatic_upgrade_test.py +2 -0
  828. onnx/test/version_converter_test.py +26 -7
  829. onnx/test/version_utils.py +8 -0
  830. onnx/tools/net_drawer.py +7 -6
  831. onnx/tools/replace_constants.py +11 -11
  832. onnx/tools/update_model_dims.py +7 -6
  833. onnx/utils.py +104 -21
  834. onnx/version.py +2 -2
  835. onnx/version_converter/adapters/split_17_18.h +1 -1
  836. onnx/version_converter/convert.h +107 -2
  837. onnx/version_converter.py +3 -2
  838. {onnx-1.16.1.dist-info → onnx-1.17.0.dist-info}/METADATA +8 -11
  839. {onnx-1.16.1.dist-info → onnx-1.17.0.dist-info}/RECORD +843 -817
  840. {onnx-1.16.1.dist-info → onnx-1.17.0.dist-info}/WHEEL +1 -1
  841. {onnx-1.16.1.dist-info → onnx-1.17.0.dist-info}/LICENSE +0 -0
  842. {onnx-1.16.1.dist-info → onnx-1.17.0.dist-info}/entry_points.txt +0 -0
  843. {onnx-1.16.1.dist-info → onnx-1.17.0.dist-info}/top_level.txt +0 -0
onnx/defs/math/defs.cc CHANGED
@@ -14,17 +14,6 @@
14
14
 
15
15
  namespace ONNX_NAMESPACE {
16
16
 
17
- inline int MathOpTwoIntegers(std::string op_type, int a, int b) {
18
- if (op_type == "Add") {
19
- return a + b;
20
- } else if (op_type == "Sub") {
21
- return a - b;
22
- } else if (op_type == "Mul") {
23
- return a * b;
24
- }
25
- fail_shape_inference("Wrong op_type name for running propagation: ", op_type);
26
- }
27
-
28
17
  inline void MathOpDataPropagator(DataPropagationContext& ctx, std::string op_type) {
29
18
  const auto input_0 = ctx.getInputData(0);
30
19
  const auto input_1 = ctx.getInputData(1);
@@ -43,7 +32,7 @@ inline void MathOpDataPropagator(DataPropagationContext& ctx, std::string op_typ
43
32
  auto& input_dim_1 = input_1->dim(size_1 == 1 ? 0 : i);
44
33
  if (input_dim_0.has_dim_value() && input_dim_1.has_dim_value()) {
45
34
  tsp.mutable_dim()->Add()->set_dim_value(
46
- MathOpTwoIntegers(op_type, input_dim_0.dim_value(), input_dim_1.dim_value()));
35
+ defs::math::utils::MathOpTwoIntegers(op_type, input_dim_0.dim_value(), input_dim_1.dim_value()));
47
36
  } else {
48
37
  // Cannot compute the value; simply add an empty dim without value and param
49
38
  tsp.mutable_dim()->Add();
@@ -341,7 +330,7 @@ ONNX_OPERATOR_SET_SCHEMA(
341
330
  }
342
331
  )ONNX"));
343
332
 
344
- static const char* ThresholdedRelu_ver10_doc = R"DOC(
333
+ static const char* ThresholdedRelu_ver22_doc = R"DOC(
345
334
  ThresholdedRelu takes one input data (Tensor<T>) and produces one output data
346
335
  (Tensor<T>) where the rectified linear function, y = x for x > alpha, y = 0 otherwise,
347
336
  is applied to the tensor elementwise.
@@ -349,16 +338,13 @@ is applied to the tensor elementwise.
349
338
 
350
339
  ONNX_OPERATOR_SET_SCHEMA(
351
340
  ThresholdedRelu,
352
- 10,
341
+ 22,
353
342
  OpSchema()
354
- .SetDoc(ThresholdedRelu_ver10_doc)
343
+ .SetDoc(ThresholdedRelu_ver22_doc)
355
344
  .Attr("alpha", "Threshold value", AttributeProto::FLOAT, 1.0f)
356
345
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
357
346
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
358
- .TypeConstraint(
359
- "T",
360
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
361
- "Constrain input and output types to float tensors.")
347
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
362
348
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
363
349
  .FunctionBody(
364
350
  R"ONNX(
@@ -373,7 +359,7 @@ ONNX_OPERATOR_SET_SCHEMA(
373
359
  )ONNX",
374
360
  18));
375
361
 
376
- static const char* Selu_ver6_doc = R"DOC(
362
+ static const char* Selu_ver22_doc = R"DOC(
377
363
  Selu takes one input data (Tensor<T>) and produces one output data
378
364
  (Tensor<T>) where the scaled exponential linear unit function,
379
365
  `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`,
@@ -382,7 +368,7 @@ is applied to the tensor elementwise.
382
368
 
383
369
  ONNX_OPERATOR_SET_SCHEMA(
384
370
  Selu,
385
- 6,
371
+ 22,
386
372
  OpSchema()
387
373
  .Attr(
388
374
  "alpha",
@@ -396,13 +382,10 @@ ONNX_OPERATOR_SET_SCHEMA(
396
382
  "(i.e., float32 approximation of 1.0507009873554804934193349852946).",
397
383
  AttributeProto::FLOAT,
398
384
  1.05070102214813232421875f)
399
- .SetDoc(Selu_ver6_doc)
385
+ .SetDoc(Selu_ver22_doc)
400
386
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
401
387
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
402
- .TypeConstraint(
403
- "T",
404
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
405
- "Constrain input and output types to float tensors.")
388
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
406
389
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
407
390
  .FunctionBody(
408
391
  R"ONNX(
@@ -424,7 +407,7 @@ ONNX_OPERATOR_SET_SCHEMA(
424
407
  )ONNX",
425
408
  18));
426
409
 
427
- static const char* Elu_ver6_doc = R"DOC(
410
+ static const char* Elu_ver22_doc = R"DOC(
428
411
  Elu takes one input data (Tensor<T>) and produces one output data
429
412
  (Tensor<T>) where the function `f(x) = alpha * (exp(x) - 1.) for x <
430
413
  0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise.
@@ -433,16 +416,13 @@ Elu takes one input data (Tensor<T>) and produces one output data
433
416
 
434
417
  ONNX_OPERATOR_SET_SCHEMA(
435
418
  Elu,
436
- 6,
419
+ 22,
437
420
  OpSchema()
438
421
  .Attr("alpha", "Coefficient of ELU.", AttributeProto::FLOAT, 1.0f)
439
- .SetDoc(Elu_ver6_doc)
422
+ .SetDoc(Elu_ver22_doc)
440
423
  .Input(0, "X", "1D input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
441
424
  .Output(0, "Y", "1D output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
442
- .TypeConstraint(
443
- "T",
444
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
445
- "Constrain input and output types to float tensors.")
425
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
446
426
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
447
427
  .FunctionBody(
448
428
  R"ONNX(
@@ -462,7 +442,7 @@ ONNX_OPERATOR_SET_SCHEMA(
462
442
  )ONNX",
463
443
  18));
464
444
 
465
- static const char* mish_ver18_doc = R"DOC(
445
+ static const char* mish_ver22_doc = R"DOC(
466
446
  Mish: A Self Regularized Non-Monotonic Neural Activation Function.
467
447
 
468
448
  Perform the linear unit element-wise on the input tensor X using formula:
@@ -474,15 +454,12 @@ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
474
454
 
475
455
  ONNX_OPERATOR_SET_SCHEMA(
476
456
  Mish,
477
- 18,
457
+ 22,
478
458
  OpSchema()
479
- .SetDoc(mish_ver18_doc)
459
+ .SetDoc(mish_ver22_doc)
480
460
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
481
461
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
482
- .TypeConstraint(
483
- "T",
484
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
485
- "Constrain input X and output types to float tensors.")
462
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input X and output types to float tensors.")
486
463
  .FunctionBody(R"ONNX(
487
464
  {
488
465
  Softplus_X = Softplus (X)
@@ -664,10 +641,7 @@ ONNX_OPERATOR_SET_SCHEMA(
664
641
  true,
665
642
  1,
666
643
  OpSchema::Differentiable)
667
- .TypeConstraint(
668
- "T",
669
- {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
670
- "Constrain input and output types to float tensors.")
644
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
671
645
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
672
646
 
673
647
  static const char* Log_ver13_doc = R"DOC(
@@ -716,10 +690,7 @@ ONNX_OPERATOR_SET_SCHEMA(
716
690
  true,
717
691
  1,
718
692
  OpSchema::Differentiable)
719
- .TypeConstraint(
720
- "T",
721
- {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
722
- "Constrain input and output types to float tensors.")
693
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
723
694
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
724
695
 
725
696
  static const char* Pow_ver15_doc = R"DOC(
@@ -842,7 +813,7 @@ ONNX_OPERATOR_SET_SCHEMA(
842
813
  "Constrain input and output types to float tensors.")
843
814
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
844
815
 
845
- static const char* HardSigmoid_ver6_doc = R"DOC(
816
+ static const char* HardSigmoid_ver22_doc = R"DOC(
846
817
  HardSigmoid takes one input data (Tensor<T>) and produces one output data
847
818
  (Tensor<T>) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)),
848
819
  is applied to the tensor elementwise.
@@ -850,17 +821,14 @@ is applied to the tensor elementwise.
850
821
 
851
822
  ONNX_OPERATOR_SET_SCHEMA(
852
823
  HardSigmoid,
853
- 6,
824
+ 22,
854
825
  OpSchema()
855
826
  .Attr("alpha", "Value of alpha.", AttributeProto::FLOAT, 0.2f)
856
827
  .Attr("beta", "Value of beta.", AttributeProto::FLOAT, 0.5f)
857
- .SetDoc(HardSigmoid_ver6_doc)
828
+ .SetDoc(HardSigmoid_ver22_doc)
858
829
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
859
830
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
860
- .TypeConstraint(
861
- "T",
862
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
863
- "Constrain input and output types to float tensors.")
831
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
864
832
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
865
833
  .FunctionBody(
866
834
  R"ONNX(
@@ -881,7 +849,7 @@ ONNX_OPERATOR_SET_SCHEMA(
881
849
  )ONNX",
882
850
  18));
883
851
 
884
- static const char* HardSwish_ver14_doc = R"DOC(
852
+ static const char* HardSwish_ver22_doc = R"DOC(
885
853
  HardSwish takes one input data (Tensor<T>) and produces one output data (Tensor<T>) where
886
854
  the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid<alpha, beta>(x),
887
855
  where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise.
@@ -889,15 +857,12 @@ where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise.
889
857
 
890
858
  ONNX_OPERATOR_SET_SCHEMA(
891
859
  HardSwish,
892
- 14,
860
+ 22,
893
861
  OpSchema()
894
- .SetDoc(HardSwish_ver14_doc)
862
+ .SetDoc(HardSwish_ver22_doc)
895
863
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
896
864
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
897
- .TypeConstraint(
898
- "T",
899
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
900
- "Constrain input and output types to float tensors.")
865
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
901
866
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
902
867
  .FunctionBody(R"ONNX(
903
868
  {
@@ -1232,15 +1197,15 @@ ONNX_OPERATOR_SET_SCHEMA(
1232
1197
  "hardmax",
1233
1198
  "Hardmax(element in input, axis) = 1 if the element is the first maximum value along the specified axis, 0 otherwise")));
1234
1199
 
1235
- static const char* Softsign_ver1_doc = R"DOC(
1200
+ static const char* Softsign_ver22_doc = R"DOC(
1236
1201
  Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise.
1237
1202
  )DOC";
1238
1203
 
1239
1204
  ONNX_OPERATOR_SET_SCHEMA(
1240
1205
  Softsign,
1241
- 1,
1206
+ 22,
1242
1207
  OpSchema()
1243
- .SetDoc(Softsign_ver1_doc)
1208
+ .SetDoc(Softsign_ver22_doc)
1244
1209
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1245
1210
  .Output(
1246
1211
  0,
@@ -1251,10 +1216,7 @@ ONNX_OPERATOR_SET_SCHEMA(
1251
1216
  true,
1252
1217
  1,
1253
1218
  OpSchema::Differentiable)
1254
- .TypeConstraint(
1255
- "T",
1256
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1257
- "Constrain input and output types to float tensors.")
1219
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1258
1220
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
1259
1221
  .FunctionBody(
1260
1222
  R"ONNX(
@@ -1268,7 +1230,7 @@ ONNX_OPERATOR_SET_SCHEMA(
1268
1230
  )ONNX",
1269
1231
  18));
1270
1232
 
1271
- static const char* Softplus_ver1_doc = R"DOC(
1233
+ static const char* Softplus_ver22_doc = R"DOC(
1272
1234
  Softplus takes one input data (Tensor<T>) and produces one output data
1273
1235
  (Tensor<T>) where the softplus function, y = ln(exp(x) + 1), is applied to
1274
1236
  the tensor elementwise.
@@ -1276,15 +1238,12 @@ the tensor elementwise.
1276
1238
 
1277
1239
  ONNX_OPERATOR_SET_SCHEMA(
1278
1240
  Softplus,
1279
- 1,
1241
+ 22,
1280
1242
  OpSchema()
1281
- .SetDoc(Softplus_ver1_doc)
1243
+ .SetDoc(Softplus_ver22_doc)
1282
1244
  .Input(0, "X", "1D input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1283
1245
  .Output(0, "Y", "1D input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1284
- .TypeConstraint(
1285
- "T",
1286
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1287
- "Constrain input and output types to float tensors.")
1246
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1288
1247
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)
1289
1248
  .FunctionBody(
1290
1249
  R"ONNX(
@@ -1386,7 +1345,7 @@ ONNX_OPERATOR_SET_SCHEMA(
1386
1345
  }));
1387
1346
 
1388
1347
  static const char* MatMul_ver13_doc = R"DOC(
1389
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
1348
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
1390
1349
  )DOC";
1391
1350
 
1392
1351
  ONNX_OPERATOR_SET_SCHEMA(
@@ -1549,15 +1508,15 @@ ONNX_OPERATOR_SET_SCHEMA(
1549
1508
  return;
1550
1509
  }));
1551
1510
 
1552
- static const char* Sin_ver7_doc = R"DOC(
1511
+ static const char* Sin_ver22_doc = R"DOC(
1553
1512
  Calculates the sine of the given input tensor, element-wise.
1554
1513
  )DOC";
1555
1514
 
1556
1515
  ONNX_OPERATOR_SET_SCHEMA(
1557
1516
  Sin,
1558
- 7,
1517
+ 22,
1559
1518
  OpSchema()
1560
- .SetDoc(Sin_ver7_doc)
1519
+ .SetDoc(Sin_ver22_doc)
1561
1520
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1562
1521
  .Output(
1563
1522
  0,
@@ -1569,21 +1528,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1569
1528
  true,
1570
1529
  1,
1571
1530
  OpSchema::Differentiable)
1572
- .TypeConstraint(
1573
- "T",
1574
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1575
- "Constrain input and output types to float tensors.")
1531
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1576
1532
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1577
1533
 
1578
- static const char* Cos_ver7_doc = R"DOC(
1534
+ static const char* Cos_ver22_doc = R"DOC(
1579
1535
  Calculates the cosine of the given input tensor, element-wise.
1580
1536
  )DOC";
1581
1537
 
1582
1538
  ONNX_OPERATOR_SET_SCHEMA(
1583
1539
  Cos,
1584
- 7,
1540
+ 22,
1585
1541
  OpSchema()
1586
- .SetDoc(Cos_ver7_doc)
1542
+ .SetDoc(Cos_ver22_doc)
1587
1543
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1588
1544
  .Output(
1589
1545
  0,
@@ -1595,21 +1551,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1595
1551
  true,
1596
1552
  1,
1597
1553
  OpSchema::Differentiable)
1598
- .TypeConstraint(
1599
- "T",
1600
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1601
- "Constrain input and output types to float tensors.")
1554
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1602
1555
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1603
1556
 
1604
- static const char* Tan_ver7_doc = R"DOC(
1557
+ static const char* Tan_ver22_doc = R"DOC(
1605
1558
  Calculates the tangent of the given input tensor, element-wise.
1606
1559
  )DOC";
1607
1560
 
1608
1561
  ONNX_OPERATOR_SET_SCHEMA(
1609
1562
  Tan,
1610
- 7,
1563
+ 22,
1611
1564
  OpSchema()
1612
- .SetDoc(Tan_ver7_doc)
1565
+ .SetDoc(Tan_ver22_doc)
1613
1566
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1614
1567
  .Output(
1615
1568
  0,
@@ -1621,21 +1574,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1621
1574
  true,
1622
1575
  1,
1623
1576
  OpSchema::Differentiable)
1624
- .TypeConstraint(
1625
- "T",
1626
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1627
- "Constrain input and output types to float tensors.")
1577
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1628
1578
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1629
1579
 
1630
- static const char* Asin_ver7_doc = R"DOC(
1580
+ static const char* Asin_ver22_doc = R"DOC(
1631
1581
  Calculates the arcsine (inverse of sine) of the given input tensor, element-wise.
1632
1582
  )DOC";
1633
1583
 
1634
1584
  ONNX_OPERATOR_SET_SCHEMA(
1635
1585
  Asin,
1636
- 7,
1586
+ 22,
1637
1587
  OpSchema()
1638
- .SetDoc(Asin_ver7_doc)
1588
+ .SetDoc(Asin_ver22_doc)
1639
1589
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1640
1590
  .Output(
1641
1591
  0,
@@ -1647,21 +1597,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1647
1597
  true,
1648
1598
  1,
1649
1599
  OpSchema::Differentiable)
1650
- .TypeConstraint(
1651
- "T",
1652
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1653
- "Constrain input and output types to float tensors.")
1600
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1654
1601
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1655
1602
 
1656
- static const char* Acos_ver7_doc = R"DOC(
1603
+ static const char* Acos_ver22_doc = R"DOC(
1657
1604
  Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise.
1658
1605
  )DOC";
1659
1606
 
1660
1607
  ONNX_OPERATOR_SET_SCHEMA(
1661
1608
  Acos,
1662
- 7,
1609
+ 22,
1663
1610
  OpSchema()
1664
- .SetDoc(Acos_ver7_doc)
1611
+ .SetDoc(Acos_ver22_doc)
1665
1612
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1666
1613
  .Output(
1667
1614
  0,
@@ -1673,21 +1620,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1673
1620
  true,
1674
1621
  1,
1675
1622
  OpSchema::Differentiable)
1676
- .TypeConstraint(
1677
- "T",
1678
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1679
- "Constrain input and output types to float tensors.")
1623
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1680
1624
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1681
1625
 
1682
- static const char* Atan_ver7_doc = R"DOC(
1626
+ static const char* Atan_ver22_doc = R"DOC(
1683
1627
  Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise.
1684
1628
  )DOC";
1685
1629
 
1686
1630
  ONNX_OPERATOR_SET_SCHEMA(
1687
1631
  Atan,
1688
- 7,
1632
+ 22,
1689
1633
  OpSchema()
1690
- .SetDoc(Atan_ver7_doc)
1634
+ .SetDoc(Atan_ver22_doc)
1691
1635
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1692
1636
  .Output(
1693
1637
  0,
@@ -1699,10 +1643,7 @@ ONNX_OPERATOR_SET_SCHEMA(
1699
1643
  true,
1700
1644
  1,
1701
1645
  OpSchema::Differentiable)
1702
- .TypeConstraint(
1703
- "T",
1704
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1705
- "Constrain input and output types to float tensors.")
1646
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1706
1647
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1707
1648
 
1708
1649
  static const char* Expand_ver13_doc = R"DOC(
@@ -1749,15 +1690,15 @@ ONNX_OPERATOR_SET_SCHEMA(
1749
1690
  }
1750
1691
  }));
1751
1692
 
1752
- static const char* Sinh_ver9_doc = R"DOC(
1693
+ static const char* Sinh_ver22_doc = R"DOC(
1753
1694
  Calculates the hyperbolic sine of the given input tensor element-wise.
1754
1695
  )DOC";
1755
1696
 
1756
1697
  ONNX_OPERATOR_SET_SCHEMA(
1757
1698
  Sinh,
1758
- 9,
1699
+ 22,
1759
1700
  OpSchema()
1760
- .SetDoc(Sinh_ver9_doc)
1701
+ .SetDoc(Sinh_ver22_doc)
1761
1702
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1762
1703
  .Output(
1763
1704
  0,
@@ -1769,21 +1710,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1769
1710
  true,
1770
1711
  1,
1771
1712
  OpSchema::Differentiable)
1772
- .TypeConstraint(
1773
- "T",
1774
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1775
- "Constrain input and output types to float tensors.")
1713
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1776
1714
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1777
1715
 
1778
- static const char* Cosh_ver9_doc = R"DOC(
1716
+ static const char* Cosh_ver22_doc = R"DOC(
1779
1717
  Calculates the hyperbolic cosine of the given input tensor element-wise.
1780
1718
  )DOC";
1781
1719
 
1782
1720
  ONNX_OPERATOR_SET_SCHEMA(
1783
1721
  Cosh,
1784
- 9,
1722
+ 22,
1785
1723
  OpSchema()
1786
- .SetDoc(Cosh_ver9_doc)
1724
+ .SetDoc(Cosh_ver22_doc)
1787
1725
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1788
1726
  .Output(
1789
1727
  0,
@@ -1795,21 +1733,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1795
1733
  true,
1796
1734
  1,
1797
1735
  OpSchema::Differentiable)
1798
- .TypeConstraint(
1799
- "T",
1800
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1801
- "Constrain input and output types to float tensors.")
1736
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1802
1737
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1803
1738
 
1804
- static const char* Asinh_ver9_doc = R"DOC(
1739
+ static const char* Asinh_ver22_doc = R"DOC(
1805
1740
  Calculates the hyperbolic arcsine of the given input tensor element-wise.
1806
1741
  )DOC";
1807
1742
 
1808
1743
  ONNX_OPERATOR_SET_SCHEMA(
1809
1744
  Asinh,
1810
- 9,
1745
+ 22,
1811
1746
  OpSchema()
1812
- .SetDoc(Asinh_ver9_doc)
1747
+ .SetDoc(Asinh_ver22_doc)
1813
1748
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1814
1749
  .Output(
1815
1750
  0,
@@ -1821,21 +1756,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1821
1756
  true,
1822
1757
  1,
1823
1758
  OpSchema::Differentiable)
1824
- .TypeConstraint(
1825
- "T",
1826
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1827
- "Constrain input and output types to float tensors.")
1759
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1828
1760
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1829
1761
 
1830
- static const char* Acosh_ver9_doc = R"DOC(
1762
+ static const char* Acosh_ver22_doc = R"DOC(
1831
1763
  Calculates the hyperbolic arccosine of the given input tensor element-wise.
1832
1764
  )DOC";
1833
1765
 
1834
1766
  ONNX_OPERATOR_SET_SCHEMA(
1835
1767
  Acosh,
1836
- 9,
1768
+ 22,
1837
1769
  OpSchema()
1838
- .SetDoc(Acosh_ver9_doc)
1770
+ .SetDoc(Acosh_ver22_doc)
1839
1771
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1840
1772
  .Output(
1841
1773
  0,
@@ -1847,21 +1779,18 @@ ONNX_OPERATOR_SET_SCHEMA(
1847
1779
  true,
1848
1780
  1,
1849
1781
  OpSchema::Differentiable)
1850
- .TypeConstraint(
1851
- "T",
1852
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1853
- "Constrain input and output types to float tensors.")
1782
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1854
1783
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1855
1784
 
1856
- static const char* Atanh_ver9_doc = R"DOC(
1785
+ static const char* Atanh_ver22_doc = R"DOC(
1857
1786
  Calculates the hyperbolic arctangent of the given input tensor element-wise.
1858
1787
  )DOC";
1859
1788
 
1860
1789
  ONNX_OPERATOR_SET_SCHEMA(
1861
1790
  Atanh,
1862
- 9,
1791
+ 22,
1863
1792
  OpSchema()
1864
- .SetDoc(Atanh_ver9_doc)
1793
+ .SetDoc(Atanh_ver22_doc)
1865
1794
  .Input(0, "input", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
1866
1795
  .Output(
1867
1796
  0,
@@ -1873,10 +1802,7 @@ ONNX_OPERATOR_SET_SCHEMA(
1873
1802
  true,
1874
1803
  1,
1875
1804
  OpSchema::Differentiable)
1876
- .TypeConstraint(
1877
- "T",
1878
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
1879
- "Constrain input and output types to float tensors.")
1805
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
1880
1806
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
1881
1807
 
1882
1808
  static const char* Sign_ver13_doc = R"DOC(
@@ -2017,7 +1943,7 @@ ONNX_OPERATOR_SET_SCHEMA(
2017
1943
  .TypeAndShapeInferenceFunction(defs::math::utils::QLinearMatMulShapeInference));
2018
1944
 
2019
1945
  static const char* MatMulInteger_ver10_doc = R"DOC(
2020
- Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html.
1946
+ Matrix product that behaves like [numpy.matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html).
2021
1947
  The production MUST never overflow. The accumulation may overflow if and only if in 32 bits.
2022
1948
  )DOC";
2023
1949
 
@@ -2154,7 +2080,7 @@ ONNX_OPERATOR_SET_SCHEMA(
2154
2080
  .TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "axis tensor can be int32 or int64 only")
2155
2081
  .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
2156
2082
 
2157
- static const char* Round_ver11_doc = R"DOC(
2083
+ static const char* Round_ver22_doc = R"DOC(
2158
2084
  Round takes one input Tensor and rounds the values, element-wise, meaning
2159
2085
  it finds the nearest integer for each value.
2160
2086
  In case of halves, the rule is to round them to the nearest even integer.
@@ -2173,18 +2099,15 @@ round([-4.5]) = [-4.0]
2173
2099
 
2174
2100
  ONNX_OPERATOR_SET_SCHEMA(
2175
2101
  Round,
2176
- 11,
2102
+ 22,
2177
2103
  OpSchema()
2178
- .SetDoc(Round_ver11_doc)
2104
+ .SetDoc(Round_ver22_doc)
2179
2105
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
2180
2106
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
2181
- .TypeConstraint(
2182
- "T",
2183
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
2184
- "Constrain input and output types to float tensors.")
2107
+ .TypeConstraint("T", OpSchema::all_float_types_ir4(), "Constrain input and output types to float tensors.")
2185
2108
  .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));
2186
2109
 
2187
- static const char* Det_ver11_doc = R"DOC(
2110
+ static const char* Det_ver22_doc = R"DOC(
2188
2111
  Det calculates determinant of a square matrix or batches of square matrices.
2189
2112
  Det takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
2190
2113
  and the inner-most 2 dimensions form square matrices.
@@ -2194,14 +2117,14 @@ e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`).
2194
2117
 
2195
2118
  ONNX_OPERATOR_SET_SCHEMA(
2196
2119
  Det,
2197
- 11,
2120
+ 22,
2198
2121
  OpSchema()
2199
- .SetDoc(Det_ver11_doc)
2122
+ .SetDoc(Det_ver22_doc)
2200
2123
  .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
2201
2124
  .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable)
2202
2125
  .TypeConstraint(
2203
2126
  "T",
2204
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
2127
+ OpSchema::all_float_types_ir4(),
2205
2128
  "Constrain input and output types to floating-point tensors.")
2206
2129
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
2207
2130
  // Type inference
@@ -2235,110 +2158,6 @@ ONNX_OPERATOR_SET_SCHEMA(
2235
2158
  }
2236
2159
  }));
2237
2160
 
2238
- static const char* NegativeLogLikelihoodLoss_ver13_doc = R"DOC(
2239
- A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss.
2240
- Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0.
2241
- The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C).
2242
- The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes)
2243
- or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples.
2244
- The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as:
2245
-
2246
- ```
2247
- loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k].
2248
- ```
2249
-
2250
- When an optional "weight" is provided, the sample loss is calculated as:
2251
-
2252
- ```
2253
- loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c].
2254
- ```
2255
-
2256
- loss is zero for the case when target-value equals ignore_index.
2257
-
2258
- ```
2259
- loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index
2260
- ```
2261
-
2262
- If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk).
2263
- If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged:
2264
-
2265
- ```
2266
- mean(loss), if "weight" is not provided,
2267
- ```
2268
-
2269
- or if weight is provided,
2270
-
2271
- ```
2272
- sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]]), for all samples.
2273
- ```
2274
-
2275
- If "reduction" attribute is set to "sum", the output is a scalar: `sum(loss)`.
2276
-
2277
- See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss.
2278
-
2279
- Example 1:
2280
-
2281
- ```
2282
- // negative log likelihood loss, "none" reduction
2283
- N, C, d1 = 2, 3, 2
2284
- input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
2285
- [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
2286
- target = [[2, 1], [0, 2]]
2287
-
2288
- loss = np.zeros((N, d1))
2289
- for n in range(N):
2290
- for d_1 in range(d1):
2291
- c = target[n][d_1]
2292
- loss[n][d_1] = -input[n][c][d_1]
2293
-
2294
- // print(loss)
2295
- // [[-3. -2.]
2296
- // [-0. -2.]]
2297
- ```
2298
-
2299
- Example 2:
2300
-
2301
- ```
2302
- // weighted negative log likelihood loss, sum reduction
2303
- N, C, d1 = 2, 3, 2
2304
- input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
2305
- [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
2306
- target = [[2, 1], [0, 2]]
2307
- weight = [0.2, 0.3, 0.1]
2308
- loss = np.zeros((N, d1))
2309
- for n in range(N):
2310
- for d_1 in range(d1):
2311
- c = target[n][d_1]
2312
- loss[n][d_1] = -input[n][c][d_1] * weight[c]
2313
-
2314
- loss = np.sum(loss)
2315
- // print(loss)
2316
- // -1.1
2317
- ```
2318
-
2319
- Example 3:
2320
-
2321
- ```
2322
- // weighted negative log likelihood loss, mean reduction
2323
- N, C, d1 = 2, 3, 2
2324
- input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
2325
- [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
2326
- target = [[2, 1], [0, 2]]
2327
- weight = [0.2, 0.3, 0.1]
2328
- loss = np.zeros((N, d1))
2329
- weight_total = 0
2330
- for n in range(N):
2331
- for d_1 in range(d1):
2332
- c = target[n][d_1]
2333
- loss[n][d_1] = -input[n][c][d_1] * weight[c]
2334
- weight_total = weight_total + weight[c]
2335
-
2336
- loss = np.sum(loss) / weight_total
2337
- // print(loss)
2338
- // -1.57
2339
- ```
2340
- )DOC";
2341
-
2342
2161
  bool BuildContextDependentFunctionBody(
2343
2162
  const FunctionBodyBuildContext& ctx,
2344
2163
  const OpSchema& schema,
@@ -2451,11 +2270,115 @@ bool BuildContextDependentFunctionBody(
2451
2270
  return true;
2452
2271
  }
2453
2272
 
2273
+ static const char* NegativeLogLikelihoodLoss_ver22_doc = R"DOC(
2274
+ A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss.
2275
+ Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0.
2276
+ The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C).
2277
+ The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes)
2278
+ or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples.
2279
+ The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as:
2280
+
2281
+ ```
2282
+ loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k].
2283
+ ```
2284
+
2285
+ When an optional "weight" is provided, the sample loss is calculated as:
2286
+
2287
+ ```
2288
+ loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c].
2289
+ ```
2290
+
2291
+ loss is zero for the case when target-value equals ignore_index.
2292
+
2293
+ ```
2294
+ loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index
2295
+ ```
2296
+
2297
+ If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk).
2298
+ If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged:
2299
+
2300
+ ```
2301
+ mean(loss), if "weight" is not provided,
2302
+ ```
2303
+
2304
+ or if weight is provided,
2305
+
2306
+ ```
2307
+ sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]]), for all samples.
2308
+ ```
2309
+
2310
+ If "reduction" attribute is set to "sum", the output is a scalar: `sum(loss)`.
2311
+
2312
+ See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss.
2313
+
2314
+ Example 1:
2315
+
2316
+ ```
2317
+ // negative log likelihood loss, "none" reduction
2318
+ N, C, d1 = 2, 3, 2
2319
+ input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
2320
+ [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
2321
+ target = [[2, 1], [0, 2]]
2322
+
2323
+ loss = np.zeros((N, d1))
2324
+ for n in range(N):
2325
+ for d_1 in range(d1):
2326
+ c = target[n][d_1]
2327
+ loss[n][d_1] = -input[n][c][d_1]
2328
+
2329
+ // print(loss)
2330
+ // [[-3. -2.]
2331
+ // [-0. -2.]]
2332
+ ```
2333
+
2334
+ Example 2:
2335
+
2336
+ ```
2337
+ // weighted negative log likelihood loss, sum reduction
2338
+ N, C, d1 = 2, 3, 2
2339
+ input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
2340
+ [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
2341
+ target = [[2, 1], [0, 2]]
2342
+ weight = [0.2, 0.3, 0.1]
2343
+ loss = np.zeros((N, d1))
2344
+ for n in range(N):
2345
+ for d_1 in range(d1):
2346
+ c = target[n][d_1]
2347
+ loss[n][d_1] = -input[n][c][d_1] * weight[c]
2348
+
2349
+ loss = np.sum(loss)
2350
+ // print(loss)
2351
+ // -1.1
2352
+ ```
2353
+
2354
+ Example 3:
2355
+
2356
+ ```
2357
+ // weighted negative log likelihood loss, mean reduction
2358
+ N, C, d1 = 2, 3, 2
2359
+ input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]],
2360
+ [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]]
2361
+ target = [[2, 1], [0, 2]]
2362
+ weight = [0.2, 0.3, 0.1]
2363
+ loss = np.zeros((N, d1))
2364
+ weight_total = 0
2365
+ for n in range(N):
2366
+ for d_1 in range(d1):
2367
+ c = target[n][d_1]
2368
+ loss[n][d_1] = -input[n][c][d_1] * weight[c]
2369
+ weight_total = weight_total + weight[c]
2370
+
2371
+ loss = np.sum(loss) / weight_total
2372
+ // print(loss)
2373
+ // -1.57
2374
+ ```
2375
+ )DOC";
2376
+
2454
2377
  ONNX_OPERATOR_SET_SCHEMA(
2455
2378
  NegativeLogLikelihoodLoss,
2456
- 13,
2379
+ 22,
2457
2380
  OpSchema()
2458
- .SetDoc(NegativeLogLikelihoodLoss_ver13_doc)
2381
+ .SetDoc(NegativeLogLikelihoodLoss_ver22_doc)
2459
2382
  .Input(
2460
2383
  0,
2461
2384
  "input",
@@ -2502,7 +2425,7 @@ ONNX_OPERATOR_SET_SCHEMA(
2502
2425
  false)
2503
2426
  .TypeConstraint(
2504
2427
  "T",
2505
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
2428
+ OpSchema::all_float_types_ir4(),
2506
2429
  "Constrain input, weight, and output types to floating-point tensors.")
2507
2430
  .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain target to integer types")
2508
2431
  .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBody)
@@ -2519,10 +2442,14 @@ ONNX_OPERATOR_SET_SCHEMA(
2519
2442
  const int target_rank = static_cast<int>(target_shape.dim_size());
2520
2443
 
2521
2444
  if (input_rank < 2) {
2522
- fail_shape_inference("Input rank must be >= 2.")
2445
+ fail_shape_inference("Input rank must be >= 2. input_rank=", input_rank);
2523
2446
  }
2524
2447
  if (target_rank != input_rank - 1) {
2525
- fail_shape_inference("Target rank must be 1 less than the input rank.");
2448
+ fail_shape_inference(
2449
+ "Target rank must be 1 less than the input rank. input_rank=",
2450
+ input_rank,
2451
+ ", target_rank=",
2452
+ target_rank);
2526
2453
  }
2527
2454
 
2528
2455
  // match input dimensions (N, C, d1, ..., dk) with target
@@ -2532,13 +2459,18 @@ ONNX_OPERATOR_SET_SCHEMA(
2532
2459
  const auto target_dim = target_shape.dim(dim);
2533
2460
  if (input_dim.has_dim_value() && target_dim.has_dim_value() &&
2534
2461
  input_dim.dim_value() != target_dim.dim_value())
2535
- fail_shape_inference("Input and target dimension value mismatch.");
2462
+ fail_shape_inference(
2463
+ "Input and target dimension value mismatch. input_dim_value=",
2464
+ input_dim.dim_value(),
2465
+ " target_dim_value=",
2466
+ target_dim.dim_value());
2536
2467
  }
2537
2468
 
2538
2469
  if (ctx.getNumInputs() == 3 && hasInputShape(ctx, 2)) {
2539
2470
  const TensorShapeProto& weight_shape = ctx.getInputType(2)->tensor_type().shape();
2540
- if (weight_shape.dim_size() != 1) {
2541
- fail_shape_inference("Weight rank must be 1.");
2471
+ const auto weight_rank = weight_shape.dim_size();
2472
+ if (weight_rank != 1) {
2473
+ fail_shape_inference("Weight rank must be 1. weight_rank=", weight_rank);
2542
2474
  }
2543
2475
  }
2544
2476
 
@@ -2559,17 +2491,17 @@ ONNX_OPERATOR_SET_SCHEMA(
2559
2491
  }
2560
2492
  }));
2561
2493
 
2562
- void einsumRankInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string equation) {
2563
- const size_t numInputs = ctx.getNumInputs();
2564
- if (numInputs < 1 || !hasNInputShapes(ctx, static_cast<int>(numInputs))) {
2494
+ void einsumShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string const& equation) {
2495
+ // Only accept letters for indices
2496
+ auto is_letter = [](char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); };
2497
+
2498
+ const size_t num_inputs = ctx.getNumInputs();
2499
+ if (num_inputs < 1 || !hasNInputShapes(ctx, static_cast<int>(num_inputs))) {
2565
2500
  return;
2566
2501
  }
2567
-
2568
- auto* output_shape = getOutputShape(ctx, 0);
2502
+ ONNX_NAMESPACE::TensorShapeProto output_shape;
2569
2503
  std::string left_equation;
2570
2504
 
2571
- equation.erase(std::remove(equation.begin(), equation.end(), ' '),
2572
- equation.end()); // Remove space char
2573
2505
  auto mid_index = equation.find("->");
2574
2506
  if (mid_index != std::string::npos) {
2575
2507
  // Separate right and left hand sides of the equation
@@ -2586,73 +2518,130 @@ void einsumRankInference(ONNX_NAMESPACE::InferenceContext& ctx, std::string equa
2586
2518
 
2587
2519
  // Parse the left-hand side
2588
2520
  std::stringstream str(left_equation);
2521
+ std::map<char, size_t> label_maps;
2522
+ std::set<char> repeated_labels;
2523
+ ONNX_NAMESPACE::TensorShapeProto dims_value, ellipsis_dims_value;
2524
+ size_t num_labels = 0;
2525
+ bool ellipsis_flag = true;
2526
+
2589
2527
  while (!str.eof()) {
2590
2528
  std::getline(str, term, ',');
2591
2529
  auto ellipsis_index = term.find("...");
2592
- if (numInputs <= num_operands) {
2530
+ if (num_inputs <= num_operands) {
2593
2531
  fail_shape_inference("Number of input tensors does not match the operands in the equation.");
2594
2532
  }
2595
- size_t rank = ctx.getInputType(num_operands)->tensor_type().shape().dim_size();
2533
+ const auto& shape = ctx.getInputType(num_operands)->tensor_type().shape();
2534
+ size_t rank = shape.dim_size();
2535
+ size_t ellipsis_dims = 0;
2536
+
2537
+ size_t term_size = 0; // number of legal indices for the current term
2538
+ size_t num_illegal_char = 0; // number of illegal char before the current 'index' in the current term
2539
+
2540
+ for (size_t index = 0; index < term.size(); ++index) {
2541
+ if (is_letter(term[index])) {
2542
+ term_size += 1;
2543
+ }
2544
+ }
2545
+
2546
+ for (size_t index = 0; index < term.size(); ++index) {
2547
+ if (index == ellipsis_index) {
2548
+ // find ellipsis and record the dims represented by ellipsis
2549
+ ellipsis_dims = rank - term_size;
2550
+ if (ellipsis_flag) {
2551
+ ellipsis_flag = false;
2552
+ for (size_t i = 0; i < ellipsis_dims; i++) {
2553
+ *ellipsis_dims_value.add_dim() = shape.dim(index + i - num_illegal_char);
2554
+ }
2555
+ } else {
2556
+ for (size_t i = 0; i < ellipsis_dims; i++) {
2557
+ const auto shape_dim = shape.dim(index + i - num_illegal_char);
2558
+ const auto current_dim = ellipsis_dims_value.mutable_dim(i);
2559
+ if (shape_dim.has_dim_value() && current_dim->has_dim_value() &&
2560
+ shape_dim.dim_value() > current_dim->dim_value() && current_dim->dim_value() == 1) {
2561
+ current_dim->set_dim_value(shape_dim.dim_value());
2562
+ }
2563
+ }
2564
+ }
2565
+ index += 2; // skip the rest of dots
2566
+ num_illegal_char += 3;
2567
+ continue;
2568
+
2569
+ } else if (!is_letter(term[index])) {
2570
+ num_illegal_char += 1;
2571
+ continue;
2572
+ }
2573
+
2574
+ const auto inserted = label_maps.insert({term[index], num_labels}).second;
2575
+ if (inserted) {
2576
+ *dims_value.add_dim() = shape.dim(index + ellipsis_dims - num_illegal_char);
2577
+ ++num_labels;
2578
+ } else {
2579
+ repeated_labels.insert(term[index]);
2580
+ }
2581
+ }
2582
+
2596
2583
  if (ellipsis_index != std::string::npos) {
2597
2584
  // If there is an ellipsis, the number of dimensions it represents
2598
2585
  // must be total dim - letter dimensions
2599
2586
  if (num_ellipsis == 0) {
2600
- if (rank + 3 < term.size()) {
2587
+ if (rank < term_size) {
2601
2588
  fail_shape_inference("Ellipsis represents incompatible dimensions.");
2602
2589
  }
2603
- num_ellipsis_indices = rank - term.size() + 3;
2590
+ num_ellipsis_indices = rank - term_size;
2604
2591
  } else { // ellipsis has been seen before. Check that if dimensions
2605
2592
  // are compatible
2606
- if (num_ellipsis_indices != rank - term.size() + 3) {
2593
+ if (num_ellipsis_indices != rank - term_size) {
2607
2594
  fail_shape_inference("Ellipsis represents incompatible dimensions.");
2608
2595
  }
2609
2596
  }
2610
2597
  num_ellipsis++;
2611
2598
  } else {
2612
- if (rank != term.size()) {
2599
+ if (rank != term_size) {
2613
2600
  fail_shape_inference("Rank of input ", num_operands, " does not match the equation indices.");
2614
2601
  }
2615
2602
  }
2616
2603
  num_operands++;
2617
2604
  }
2618
2605
 
2619
- if (numInputs != num_operands) {
2606
+ if (num_inputs != num_operands) {
2620
2607
  fail_shape_inference("Number of input tensors does not match the operands in the equation.");
2621
2608
  }
2622
2609
 
2623
- const size_t number_of_letters = 26;
2624
- size_t num_letter_occurrences[number_of_letters] = {0};
2625
2610
  // Parse the provided right-hand side
2626
2611
  if (mid_index != std::string::npos) {
2627
2612
  std::string right_equation = equation.substr(mid_index + 2);
2628
2613
  auto right_ellipsis_index = right_equation.find("...");
2629
- if (right_ellipsis_index != std::string::npos) { // Right-hand side contains ellipsis
2630
- for (size_t i = 0; i < num_ellipsis_indices; ++i) {
2631
- output_shape->add_dim();
2614
+
2615
+ for (size_t index = 0; index < right_equation.size(); ++index) {
2616
+ // If there's an ellipsis, add its corresponding dimensions
2617
+ if (index == right_ellipsis_index) {
2618
+ for (size_t i = 0; i < num_ellipsis_indices; i++) {
2619
+ *output_shape.add_dim() = ellipsis_dims_value.dim(i);
2620
+ }
2621
+ index += 2; // skip the rest of dots
2622
+ continue;
2632
2623
  }
2633
- }
2634
- for (char c : right_equation) { // Add a dimension per each character
2635
- // in right hand equation
2636
- if (c != '.') {
2637
- output_shape->add_dim();
2624
+
2625
+ if (is_letter(right_equation[index])) {
2626
+ *output_shape.add_dim() = dims_value.dim(label_maps[right_equation[index]]);
2638
2627
  }
2639
2628
  }
2640
2629
  } else { // Infer the dimension for right-hand side
2641
- // If there's an ellipsis, add it's corresponding dimensions
2630
+ // If there's an ellipsis, add its corresponding dimensions
2642
2631
  for (size_t i = 0; i < num_ellipsis_indices; i++) {
2643
- output_shape->add_dim();
2632
+ *output_shape.add_dim() = ellipsis_dims_value.dim(i);
2644
2633
  }
2645
- for (size_t i = 0; i < left_equation.size(); i++) { // Count chars that appear exactly once on left hand side
2646
- if ((left_equation.at(i) != ',') && (left_equation.at(i) != '.')) {
2647
- num_letter_occurrences[left_equation.at(i) - 'a']++;
2648
- }
2649
- }
2650
- for (size_t index = 0; index < number_of_letters; index++) {
2651
- if (num_letter_occurrences[index] == 1) {
2652
- output_shape->add_dim();
2634
+ // If no explicit output was given, generate an implicit output by ordering all the
2635
+ // labels in alphabetic order (by ASCII value consistent with numpy, so Z < a).
2636
+ // Exclude any labels that occurred more than once, as these cancel out.
2637
+ for (auto i : label_maps) {
2638
+ if (repeated_labels.count(i.first) == 0) {
2639
+ *output_shape.add_dim() = dims_value.dim(i.second);
2653
2640
  }
2654
2641
  }
2655
2642
  }
2643
+
2644
+ updateOutputShape(ctx, 0, output_shape);
2656
2645
  }
2657
2646
 
2658
2647
  static const char* Einsum_ver12_doc = R"DOC(
@@ -2702,7 +2691,10 @@ ONNX_OPERATOR_SET_SCHEMA(
2702
2691
  if (equation.compare("") == 0) {
2703
2692
  return;
2704
2693
  }
2705
- einsumRankInference(ctx, equation);
2694
+
2695
+ equation.erase(std::remove(equation.begin(), equation.end(), ' '),
2696
+ equation.end()); // Remove space char
2697
+ einsumShapeInference(ctx, equation);
2706
2698
  }));
2707
2699
 
2708
2700
  const char* reduction_doc_sce =
@@ -3480,6 +3472,9 @@ ONNX_OPERATOR_SET_SCHEMA(
3480
3472
  }
3481
3473
 
3482
3474
  auto& input_shape = getInputShape(ctx, 0);
3475
+ if (input_shape.dim_size() < 2) {
3476
+ fail_shape_inference("First input should have at least 2 dimensions in ", ctx.getDisplayName(), ".");
3477
+ }
3483
3478
  auto signal_dim = input_shape.dim(1);
3484
3479
  if (!signal_dim.has_dim_value()) {
3485
3480
  return;