onnx 1.16.2__cp39-cp39-win32.whl → 1.17.0__cp39-cp39-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx might be problematic. Click here for more details.

Files changed (843)
  1. onnx/__init__.py +3 -1
  2. onnx/_custom_element_types.py +63 -0
  3. onnx/backend/base.py +17 -15
  4. onnx/backend/sample/ops/__init__.py +4 -4
  5. onnx/backend/sample/ops/abs.py +1 -0
  6. onnx/backend/test/__init__.py +1 -0
  7. onnx/backend/test/case/__init__.py +2 -2
  8. onnx/backend/test/case/base.py +6 -5
  9. onnx/backend/test/case/model/__init__.py +4 -3
  10. onnx/backend/test/case/model/expand.py +1 -0
  11. onnx/backend/test/case/model/gradient.py +1 -0
  12. onnx/backend/test/case/model/sequence.py +3 -1
  13. onnx/backend/test/case/model/shrink.py +1 -0
  14. onnx/backend/test/case/model/sign.py +1 -0
  15. onnx/backend/test/case/model/single-relu.py +1 -0
  16. onnx/backend/test/case/model/stringnormalizer.py +1 -1
  17. onnx/backend/test/case/node/__init__.py +31 -22
  18. onnx/backend/test/case/node/_image_decoder_data.py +1 -0
  19. onnx/backend/test/case/node/abs.py +1 -0
  20. onnx/backend/test/case/node/acos.py +1 -0
  21. onnx/backend/test/case/node/acosh.py +1 -0
  22. onnx/backend/test/case/node/adagrad.py +2 -1
  23. onnx/backend/test/case/node/adam.py +4 -1
  24. onnx/backend/test/case/node/add.py +1 -0
  25. onnx/backend/test/case/node/affinegrid.py +1 -0
  26. onnx/backend/test/case/node/ai_onnx_ml/array_feature_extractor.py +1 -0
  27. onnx/backend/test/case/node/ai_onnx_ml/binarizer.py +1 -0
  28. onnx/backend/test/case/node/ai_onnx_ml/label_encoder.py +1 -0
  29. onnx/backend/test/case/node/ai_onnx_ml/tree_ensemble.py +1 -0
  30. onnx/backend/test/case/node/and.py +1 -0
  31. onnx/backend/test/case/node/argmax.py +1 -0
  32. onnx/backend/test/case/node/argmin.py +1 -0
  33. onnx/backend/test/case/node/asin.py +1 -0
  34. onnx/backend/test/case/node/asinh.py +1 -0
  35. onnx/backend/test/case/node/atan.py +1 -0
  36. onnx/backend/test/case/node/atanh.py +1 -0
  37. onnx/backend/test/case/node/averagepool.py +1 -0
  38. onnx/backend/test/case/node/batchnorm.py +1 -0
  39. onnx/backend/test/case/node/bernoulli.py +1 -0
  40. onnx/backend/test/case/node/bitshift.py +1 -0
  41. onnx/backend/test/case/node/bitwiseand.py +1 -0
  42. onnx/backend/test/case/node/bitwisenot.py +1 -0
  43. onnx/backend/test/case/node/bitwiseor.py +1 -0
  44. onnx/backend/test/case/node/bitwisexor.py +1 -0
  45. onnx/backend/test/case/node/blackmanwindow.py +13 -3
  46. onnx/backend/test/case/node/cast.py +2 -1
  47. onnx/backend/test/case/node/castlike.py +1 -0
  48. onnx/backend/test/case/node/ceil.py +1 -0
  49. onnx/backend/test/case/node/celu.py +1 -0
  50. onnx/backend/test/case/node/center_crop_pad.py +1 -0
  51. onnx/backend/test/case/node/clip.py +1 -0
  52. onnx/backend/test/case/node/col2im.py +1 -1
  53. onnx/backend/test/case/node/compress.py +1 -0
  54. onnx/backend/test/case/node/concat.py +3 -2
  55. onnx/backend/test/case/node/constant.py +1 -0
  56. onnx/backend/test/case/node/constantofshape.py +1 -0
  57. onnx/backend/test/case/node/conv.py +1 -0
  58. onnx/backend/test/case/node/convinteger.py +1 -0
  59. onnx/backend/test/case/node/convtranspose.py +135 -0
  60. onnx/backend/test/case/node/cos.py +1 -0
  61. onnx/backend/test/case/node/cosh.py +1 -0
  62. onnx/backend/test/case/node/cumsum.py +1 -0
  63. onnx/backend/test/case/node/deformconv.py +17 -26
  64. onnx/backend/test/case/node/depthtospace.py +1 -0
  65. onnx/backend/test/case/node/dequantizelinear.py +1 -0
  66. onnx/backend/test/case/node/det.py +1 -0
  67. onnx/backend/test/case/node/dft.py +1 -0
  68. onnx/backend/test/case/node/div.py +1 -0
  69. onnx/backend/test/case/node/dropout.py +1 -0
  70. onnx/backend/test/case/node/dynamicquantizelinear.py +1 -0
  71. onnx/backend/test/case/node/einsum.py +2 -3
  72. onnx/backend/test/case/node/elu.py +1 -0
  73. onnx/backend/test/case/node/equal.py +1 -0
  74. onnx/backend/test/case/node/erf.py +1 -0
  75. onnx/backend/test/case/node/exp.py +1 -0
  76. onnx/backend/test/case/node/expand.py +1 -0
  77. onnx/backend/test/case/node/eyelike.py +1 -0
  78. onnx/backend/test/case/node/flatten.py +1 -0
  79. onnx/backend/test/case/node/floor.py +1 -0
  80. onnx/backend/test/case/node/gather.py +1 -0
  81. onnx/backend/test/case/node/gatherelements.py +1 -0
  82. onnx/backend/test/case/node/gathernd.py +1 -0
  83. onnx/backend/test/case/node/gelu.py +1 -0
  84. onnx/backend/test/case/node/gemm.py +3 -4
  85. onnx/backend/test/case/node/globalaveragepool.py +1 -0
  86. onnx/backend/test/case/node/globalmaxpool.py +1 -0
  87. onnx/backend/test/case/node/greater.py +1 -0
  88. onnx/backend/test/case/node/greater_equal.py +1 -0
  89. onnx/backend/test/case/node/gridsample.py +1 -0
  90. onnx/backend/test/case/node/groupnormalization.py +1 -0
  91. onnx/backend/test/case/node/gru.py +3 -2
  92. onnx/backend/test/case/node/hammingwindow.py +13 -2
  93. onnx/backend/test/case/node/hannwindow.py +10 -2
  94. onnx/backend/test/case/node/hardmax.py +1 -0
  95. onnx/backend/test/case/node/hardsigmoid.py +1 -0
  96. onnx/backend/test/case/node/hardswish.py +1 -0
  97. onnx/backend/test/case/node/identity.py +1 -0
  98. onnx/backend/test/case/node/if.py +1 -0
  99. onnx/backend/test/case/node/instancenorm.py +1 -0
  100. onnx/backend/test/case/node/isinf.py +1 -0
  101. onnx/backend/test/case/node/isnan.py +1 -0
  102. onnx/backend/test/case/node/layernormalization.py +1 -0
  103. onnx/backend/test/case/node/leakyrelu.py +1 -0
  104. onnx/backend/test/case/node/less.py +1 -0
  105. onnx/backend/test/case/node/less_equal.py +1 -0
  106. onnx/backend/test/case/node/log.py +1 -0
  107. onnx/backend/test/case/node/logsoftmax.py +1 -0
  108. onnx/backend/test/case/node/loop.py +4 -3
  109. onnx/backend/test/case/node/lppool.py +1 -0
  110. onnx/backend/test/case/node/lrn.py +1 -0
  111. onnx/backend/test/case/node/lstm.py +3 -2
  112. onnx/backend/test/case/node/matmul.py +1 -0
  113. onnx/backend/test/case/node/matmulinteger.py +1 -0
  114. onnx/backend/test/case/node/max.py +1 -0
  115. onnx/backend/test/case/node/maxpool.py +1 -0
  116. onnx/backend/test/case/node/maxunpool.py +1 -0
  117. onnx/backend/test/case/node/mean.py +1 -0
  118. onnx/backend/test/case/node/meanvariancenormalization.py +1 -0
  119. onnx/backend/test/case/node/melweightmatrix.py +1 -0
  120. onnx/backend/test/case/node/min.py +1 -0
  121. onnx/backend/test/case/node/mish.py +1 -0
  122. onnx/backend/test/case/node/mod.py +1 -0
  123. onnx/backend/test/case/node/momentum.py +1 -0
  124. onnx/backend/test/case/node/mul.py +1 -0
  125. onnx/backend/test/case/node/neg.py +1 -0
  126. onnx/backend/test/case/node/negativeloglikelihoodloss.py +4 -1
  127. onnx/backend/test/case/node/nonmaxsuppression.py +1 -0
  128. onnx/backend/test/case/node/nonzero.py +1 -0
  129. onnx/backend/test/case/node/not.py +1 -0
  130. onnx/backend/test/case/node/onehot.py +1 -0
  131. onnx/backend/test/case/node/optionalgetelement.py +3 -2
  132. onnx/backend/test/case/node/optionalhaselement.py +2 -3
  133. onnx/backend/test/case/node/or.py +1 -0
  134. onnx/backend/test/case/node/pad.py +2 -1
  135. onnx/backend/test/case/node/pow.py +1 -0
  136. onnx/backend/test/case/node/prelu.py +1 -0
  137. onnx/backend/test/case/node/qlinearconv.py +1 -0
  138. onnx/backend/test/case/node/qlinearmatmul.py +1 -0
  139. onnx/backend/test/case/node/quantizelinear.py +1 -0
  140. onnx/backend/test/case/node/rangeop.py +1 -0
  141. onnx/backend/test/case/node/reciprocal.py +1 -0
  142. onnx/backend/test/case/node/reduce_log_sum.py +1 -0
  143. onnx/backend/test/case/node/reduce_log_sum_exp.py +1 -0
  144. onnx/backend/test/case/node/reducel1.py +1 -0
  145. onnx/backend/test/case/node/reducel2.py +1 -0
  146. onnx/backend/test/case/node/reducemax.py +2 -1
  147. onnx/backend/test/case/node/reducemean.py +1 -0
  148. onnx/backend/test/case/node/reducemin.py +1 -0
  149. onnx/backend/test/case/node/reduceprod.py +1 -0
  150. onnx/backend/test/case/node/reducesum.py +2 -1
  151. onnx/backend/test/case/node/reducesumsquare.py +1 -0
  152. onnx/backend/test/case/node/regex_full_match.py +1 -0
  153. onnx/backend/test/case/node/relu.py +1 -0
  154. onnx/backend/test/case/node/reshape.py +1 -0
  155. onnx/backend/test/case/node/resize.py +3 -2
  156. onnx/backend/test/case/node/reversesequence.py +1 -0
  157. onnx/backend/test/case/node/rnn.py +3 -2
  158. onnx/backend/test/case/node/roialign.py +1 -0
  159. onnx/backend/test/case/node/round.py +4 -3
  160. onnx/backend/test/case/node/scan.py +1 -0
  161. onnx/backend/test/case/node/scatter.py +1 -0
  162. onnx/backend/test/case/node/scatterelements.py +7 -3
  163. onnx/backend/test/case/node/scatternd.py +1 -0
  164. onnx/backend/test/case/node/selu.py +1 -0
  165. onnx/backend/test/case/node/sequence_map.py +1 -0
  166. onnx/backend/test/case/node/sequenceinsert.py +4 -3
  167. onnx/backend/test/case/node/shape.py +1 -0
  168. onnx/backend/test/case/node/shrink.py +1 -0
  169. onnx/backend/test/case/node/sigmoid.py +1 -0
  170. onnx/backend/test/case/node/sign.py +1 -0
  171. onnx/backend/test/case/node/sin.py +1 -0
  172. onnx/backend/test/case/node/sinh.py +1 -0
  173. onnx/backend/test/case/node/size.py +1 -0
  174. onnx/backend/test/case/node/slice.py +1 -0
  175. onnx/backend/test/case/node/softmax.py +1 -0
  176. onnx/backend/test/case/node/softmaxcrossentropy.py +4 -1
  177. onnx/backend/test/case/node/softplus.py +1 -0
  178. onnx/backend/test/case/node/softsign.py +1 -0
  179. onnx/backend/test/case/node/spacetodepth.py +1 -0
  180. onnx/backend/test/case/node/split.py +1 -0
  181. onnx/backend/test/case/node/splittosequence.py +1 -0
  182. onnx/backend/test/case/node/sqrt.py +1 -0
  183. onnx/backend/test/case/node/squeeze.py +1 -0
  184. onnx/backend/test/case/node/stft.py +4 -1
  185. onnx/backend/test/case/node/string_concat.py +1 -0
  186. onnx/backend/test/case/node/string_split.py +1 -0
  187. onnx/backend/test/case/node/stringnormalizer.py +1 -0
  188. onnx/backend/test/case/node/sub.py +1 -0
  189. onnx/backend/test/case/node/sum.py +1 -0
  190. onnx/backend/test/case/node/tan.py +1 -0
  191. onnx/backend/test/case/node/tanh.py +1 -0
  192. onnx/backend/test/case/node/tfidfvectorizer.py +1 -0
  193. onnx/backend/test/case/node/thresholdedrelu.py +1 -0
  194. onnx/backend/test/case/node/tile.py +1 -0
  195. onnx/backend/test/case/node/topk.py +1 -0
  196. onnx/backend/test/case/node/transpose.py +1 -0
  197. onnx/backend/test/case/node/trilu.py +1 -0
  198. onnx/backend/test/case/node/unique.py +7 -0
  199. onnx/backend/test/case/node/unsqueeze.py +1 -0
  200. onnx/backend/test/case/node/upsample.py +1 -0
  201. onnx/backend/test/case/node/where.py +1 -0
  202. onnx/backend/test/case/node/xor.py +1 -0
  203. onnx/backend/test/case/test_case.py +6 -5
  204. onnx/backend/test/case/utils.py +2 -2
  205. onnx/backend/test/cmd_tools.py +1 -0
  206. onnx/backend/test/data/node/test_acos/model.onnx +0 -0
  207. onnx/backend/test/data/node/test_acos/test_data_set_0/output_0.pb +0 -0
  208. onnx/backend/test/data/node/test_acos_example/model.onnx +0 -0
  209. onnx/backend/test/data/node/test_acosh/model.onnx +0 -0
  210. onnx/backend/test/data/node/test_acosh/test_data_set_0/output_0.pb +1 -1
  211. onnx/backend/test/data/node/test_acosh_example/model.onnx +0 -0
  212. onnx/backend/test/data/node/test_asin/model.onnx +0 -0
  213. onnx/backend/test/data/node/test_asin/test_data_set_0/output_0.pb +1 -1
  214. onnx/backend/test/data/node/test_asin_example/model.onnx +0 -0
  215. onnx/backend/test/data/node/test_asinh/model.onnx +0 -0
  216. onnx/backend/test/data/node/test_asinh/test_data_set_0/output_0.pb +1 -1
  217. onnx/backend/test/data/node/test_asinh_example/model.onnx +0 -0
  218. onnx/backend/test/data/node/test_atan/model.onnx +0 -0
  219. onnx/backend/test/data/node/test_atan/test_data_set_0/output_0.pb +1 -1
  220. onnx/backend/test/data/node/test_atan_example/model.onnx +0 -0
  221. onnx/backend/test/data/node/test_atanh/model.onnx +0 -0
  222. onnx/backend/test/data/node/test_atanh/test_data_set_0/output_0.pb +2 -2
  223. onnx/backend/test/data/node/test_atanh_example/model.onnx +0 -0
  224. onnx/backend/test/data/node/test_averagepool_1d_default/model.onnx +0 -0
  225. onnx/backend/test/data/node/test_averagepool_2d_ceil/model.onnx +0 -0
  226. onnx/backend/test/data/node/test_averagepool_2d_default/model.onnx +0 -0
  227. onnx/backend/test/data/node/test_averagepool_2d_dilations/model.onnx +0 -0
  228. onnx/backend/test/data/node/test_averagepool_2d_pads/model.onnx +0 -0
  229. onnx/backend/test/data/node/test_averagepool_2d_pads_count_include_pad/model.onnx +0 -0
  230. onnx/backend/test/data/node/test_averagepool_2d_precomputed_pads/model.onnx +0 -0
  231. onnx/backend/test/data/node/test_averagepool_2d_precomputed_pads_count_include_pad/model.onnx +0 -0
  232. onnx/backend/test/data/node/test_averagepool_2d_precomputed_same_upper/model.onnx +0 -0
  233. onnx/backend/test/data/node/test_averagepool_2d_precomputed_strides/model.onnx +0 -0
  234. onnx/backend/test/data/node/test_averagepool_2d_same_lower/model.onnx +0 -0
  235. onnx/backend/test/data/node/test_averagepool_2d_same_upper/model.onnx +0 -0
  236. onnx/backend/test/data/node/test_averagepool_2d_strides/model.onnx +0 -0
  237. onnx/backend/test/data/node/test_averagepool_3d_default/model.onnx +0 -0
  238. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_0_ceil_mode_is_False/model.onnx +0 -0
  239. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_0_ceil_mode_is_True/model.onnx +0 -0
  240. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_False/model.onnx +0 -0
  241. onnx/backend/test/data/node/test_averagepool_3d_dilations_large_count_include_pad_is_1_ceil_mode_is_True/model.onnx +0 -0
  242. onnx/backend/test/data/node/test_averagepool_3d_dilations_small/model.onnx +0 -0
  243. onnx/backend/test/data/node/test_basic_conv_with_padding/model.onnx +0 -0
  244. onnx/backend/test/data/node/test_basic_conv_without_padding/model.onnx +0 -0
  245. onnx/backend/test/data/node/test_basic_deform_conv_with_padding/model.onnx +0 -0
  246. onnx/backend/test/data/node/test_basic_deform_conv_without_padding/model.onnx +0 -0
  247. onnx/backend/test/data/node/test_bernoulli/model.onnx +0 -0
  248. onnx/backend/test/data/node/test_bernoulli_double/model.onnx +0 -0
  249. onnx/backend/test/data/node/test_bernoulli_double_expanded/model.onnx +0 -0
  250. onnx/backend/test/data/node/test_bernoulli_expanded/model.onnx +0 -0
  251. onnx/backend/test/data/node/test_bernoulli_seed/model.onnx +0 -0
  252. onnx/backend/test/data/node/test_bernoulli_seed_expanded/model.onnx +0 -0
  253. onnx/backend/test/data/node/test_blackmanwindow/test_data_set_0/output_0.pb +0 -0
  254. onnx/backend/test/data/node/test_blackmanwindow_expanded/test_data_set_0/output_0.pb +0 -0
  255. onnx/backend/test/data/node/test_blackmanwindow_symmetric/test_data_set_0/output_0.pb +0 -0
  256. onnx/backend/test/data/node/test_blackmanwindow_symmetric_expanded/test_data_set_0/output_0.pb +0 -0
  257. onnx/backend/test/data/node/test_cast_FLOAT16_to_INT4/test_data_set_0/output_0.pb +1 -1
  258. onnx/backend/test/data/node/test_cast_FLOAT_to_INT4/test_data_set_0/output_0.pb +1 -1
  259. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT/test_data_set_0/input_0.pb +1 -1
  260. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT16/test_data_set_0/input_0.pb +1 -1
  261. onnx/backend/test/data/node/test_cast_INT4_to_INT8/test_data_set_0/input_0.pb +1 -1
  262. onnx/backend/test/data/node/test_conv_with_autopad_same/model.onnx +0 -0
  263. onnx/backend/test/data/node/test_conv_with_strides_and_asymmetric_padding/model.onnx +0 -0
  264. onnx/backend/test/data/node/test_conv_with_strides_no_padding/model.onnx +0 -0
  265. onnx/backend/test/data/node/test_conv_with_strides_padding/model.onnx +0 -0
  266. onnx/backend/test/data/node/test_convtranspose/model.onnx +0 -0
  267. onnx/backend/test/data/node/test_convtranspose_1d/model.onnx +0 -0
  268. onnx/backend/test/data/node/test_convtranspose_3d/model.onnx +0 -0
  269. onnx/backend/test/data/node/test_convtranspose_autopad_same/model.onnx +0 -0
  270. onnx/backend/test/data/node/test_convtranspose_dilations/model.onnx +0 -0
  271. onnx/backend/test/data/node/test_convtranspose_group_2/model.onnx +0 -0
  272. onnx/backend/test/data/node/test_convtranspose_group_2/test_data_set_0/input_0.pb +0 -0
  273. onnx/backend/test/data/node/test_convtranspose_group_2/test_data_set_0/input_1.pb +0 -0
  274. onnx/backend/test/data/node/test_convtranspose_group_2/test_data_set_0/output_0.pb +0 -0
  275. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/model.onnx +0 -0
  276. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/test_data_set_0/input_0.pb +0 -0
  277. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/test_data_set_0/input_1.pb +0 -0
  278. onnx/backend/test/data/node/test_convtranspose_group_2_image_3/test_data_set_0/output_0.pb +0 -0
  279. onnx/backend/test/data/node/test_convtranspose_kernel_shape/model.onnx +0 -0
  280. onnx/backend/test/data/node/test_convtranspose_output_shape/model.onnx +0 -0
  281. onnx/backend/test/data/node/test_convtranspose_pad/model.onnx +0 -0
  282. onnx/backend/test/data/node/test_convtranspose_pads/model.onnx +0 -0
  283. onnx/backend/test/data/node/test_cos/model.onnx +0 -0
  284. onnx/backend/test/data/node/test_cos_example/model.onnx +0 -0
  285. onnx/backend/test/data/node/test_cosh/model.onnx +0 -0
  286. onnx/backend/test/data/node/test_cosh/test_data_set_0/output_0.pb +1 -1
  287. onnx/backend/test/data/node/test_cosh_example/model.onnx +0 -0
  288. onnx/backend/test/data/node/test_cosh_example/test_data_set_0/output_0.pb +0 -0
  289. onnx/backend/test/data/node/test_deform_conv_with_mask_bias/model.onnx +0 -0
  290. onnx/backend/test/data/node/test_deform_conv_with_multiple_offset_groups/model.onnx +0 -0
  291. onnx/backend/test/data/node/test_dequantizelinear_int4/test_data_set_0/input_0.pb +1 -1
  292. onnx/backend/test/data/node/test_det_2d/model.onnx +0 -0
  293. onnx/backend/test/data/node/test_det_nd/model.onnx +0 -0
  294. onnx/backend/test/data/node/test_dft/test_data_set_0/output_0.pb +0 -0
  295. onnx/backend/test/data/node/test_dft_axis/test_data_set_0/output_0.pb +0 -0
  296. onnx/backend/test/data/node/test_dft_axis_opset19/test_data_set_0/output_0.pb +0 -0
  297. onnx/backend/test/data/node/test_dft_inverse/test_data_set_0/output_0.pb +0 -0
  298. onnx/backend/test/data/node/test_dft_inverse_opset19/test_data_set_0/output_0.pb +0 -0
  299. onnx/backend/test/data/node/test_dft_opset19/test_data_set_0/output_0.pb +0 -0
  300. onnx/backend/test/data/node/test_dropout_default/model.onnx +0 -0
  301. onnx/backend/test/data/node/test_dropout_default_mask/model.onnx +0 -0
  302. onnx/backend/test/data/node/test_dropout_default_mask_ratio/model.onnx +0 -0
  303. onnx/backend/test/data/node/test_dropout_default_ratio/model.onnx +0 -0
  304. onnx/backend/test/data/node/test_elu/model.onnx +0 -0
  305. onnx/backend/test/data/node/test_elu_default/model.onnx +0 -0
  306. onnx/backend/test/data/node/test_elu_example/model.onnx +0 -0
  307. onnx/backend/test/data/node/test_eyelike_populate_off_main_diagonal/model.onnx +0 -0
  308. onnx/backend/test/data/node/test_eyelike_with_dtype/model.onnx +0 -0
  309. onnx/backend/test/data/node/test_eyelike_without_dtype/model.onnx +0 -0
  310. onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb +0 -0
  311. onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb +0 -0
  312. onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb +4 -3
  313. onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb +4 -3
  314. onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb +0 -0
  315. onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb +0 -0
  316. onnx/backend/test/data/node/test_globalaveragepool/model.onnx +0 -0
  317. onnx/backend/test/data/node/test_globalaveragepool_precomputed/model.onnx +0 -0
  318. onnx/backend/test/data/node/test_globalmaxpool/model.onnx +0 -0
  319. onnx/backend/test/data/node/test_globalmaxpool_precomputed/model.onnx +0 -0
  320. onnx/backend/test/data/node/test_gridsample/model.onnx +0 -0
  321. onnx/backend/test/data/node/test_gridsample_aligncorners_true/model.onnx +0 -0
  322. onnx/backend/test/data/node/test_gridsample_bicubic/model.onnx +0 -0
  323. onnx/backend/test/data/node/test_gridsample_bicubic_align_corners_0_additional_1/model.onnx +0 -0
  324. onnx/backend/test/data/node/test_gridsample_bicubic_align_corners_1_additional_1/model.onnx +0 -0
  325. onnx/backend/test/data/node/test_gridsample_bilinear/model.onnx +0 -0
  326. onnx/backend/test/data/node/test_gridsample_bilinear_align_corners_0_additional_1/model.onnx +0 -0
  327. onnx/backend/test/data/node/test_gridsample_bilinear_align_corners_1_additional_1/model.onnx +0 -0
  328. onnx/backend/test/data/node/test_gridsample_border_padding/model.onnx +0 -0
  329. onnx/backend/test/data/node/test_gridsample_nearest/model.onnx +0 -0
  330. onnx/backend/test/data/node/test_gridsample_nearest_align_corners_0_additional_1/model.onnx +0 -0
  331. onnx/backend/test/data/node/test_gridsample_nearest_align_corners_1_additional_1/model.onnx +0 -0
  332. onnx/backend/test/data/node/test_gridsample_reflection_padding/model.onnx +0 -0
  333. onnx/backend/test/data/node/test_gridsample_volumetric_bilinear_align_corners_0/model.onnx +0 -0
  334. onnx/backend/test/data/node/test_gridsample_volumetric_bilinear_align_corners_1/model.onnx +0 -0
  335. onnx/backend/test/data/node/test_gridsample_volumetric_nearest_align_corners_0/model.onnx +0 -0
  336. onnx/backend/test/data/node/test_gridsample_volumetric_nearest_align_corners_1/model.onnx +0 -0
  337. onnx/backend/test/data/node/test_gridsample_zeros_padding/model.onnx +0 -0
  338. onnx/backend/test/data/node/test_gru_batchwise/model.onnx +0 -0
  339. onnx/backend/test/data/node/test_gru_defaults/model.onnx +0 -0
  340. onnx/backend/test/data/node/test_gru_seq_length/model.onnx +0 -0
  341. onnx/backend/test/data/node/test_gru_with_initial_bias/model.onnx +0 -0
  342. onnx/backend/test/data/node/test_hammingwindow/test_data_set_0/output_0.pb +0 -0
  343. onnx/backend/test/data/node/test_hammingwindow_expanded/test_data_set_0/output_0.pb +0 -0
  344. onnx/backend/test/data/node/test_hammingwindow_symmetric/test_data_set_0/output_0.pb +1 -1
  345. onnx/backend/test/data/node/test_hammingwindow_symmetric_expanded/test_data_set_0/output_0.pb +1 -1
  346. onnx/backend/test/data/node/test_hannwindow/test_data_set_0/output_0.pb +0 -0
  347. onnx/backend/test/data/node/test_hannwindow_expanded/test_data_set_0/output_0.pb +0 -0
  348. onnx/backend/test/data/node/test_hannwindow_symmetric/test_data_set_0/output_0.pb +0 -0
  349. onnx/backend/test/data/node/test_hannwindow_symmetric_expanded/test_data_set_0/output_0.pb +0 -0
  350. onnx/backend/test/data/node/test_hardsigmoid/model.onnx +0 -0
  351. onnx/backend/test/data/node/test_hardsigmoid_default/model.onnx +0 -0
  352. onnx/backend/test/data/node/test_hardsigmoid_example/model.onnx +0 -0
  353. onnx/backend/test/data/node/test_hardswish/model.onnx +0 -0
  354. onnx/backend/test/data/node/test_hardswish_expanded/model.onnx +0 -0
  355. onnx/backend/test/data/node/test_image_decoder_decode_jpeg2k_rgb/test_data_set_0/input_0.pb +0 -0
  356. onnx/backend/test/data/node/test_instancenorm_epsilon/model.onnx +0 -0
  357. onnx/backend/test/data/node/test_instancenorm_example/model.onnx +0 -0
  358. onnx/backend/test/data/node/test_lppool_1d_default/model.onnx +0 -0
  359. onnx/backend/test/data/node/test_lppool_1d_default/test_data_set_0/output_0.pb +2 -2
  360. onnx/backend/test/data/node/test_lppool_2d_default/model.onnx +0 -0
  361. onnx/backend/test/data/node/test_lppool_2d_default/test_data_set_0/output_0.pb +0 -0
  362. onnx/backend/test/data/node/test_lppool_2d_dilations/model.onnx +0 -0
  363. onnx/backend/test/data/node/test_lppool_2d_pads/model.onnx +0 -0
  364. onnx/backend/test/data/node/test_lppool_2d_pads/test_data_set_0/output_0.pb +0 -0
  365. onnx/backend/test/data/node/test_lppool_2d_same_lower/model.onnx +0 -0
  366. onnx/backend/test/data/node/test_lppool_2d_same_lower/test_data_set_0/output_0.pb +0 -0
  367. onnx/backend/test/data/node/test_lppool_2d_same_upper/model.onnx +0 -0
  368. onnx/backend/test/data/node/test_lppool_2d_same_upper/test_data_set_0/output_0.pb +0 -0
  369. onnx/backend/test/data/node/test_lppool_2d_strides/model.onnx +0 -0
  370. onnx/backend/test/data/node/test_lppool_2d_strides/test_data_set_0/output_0.pb +0 -0
  371. onnx/backend/test/data/node/test_lppool_3d_default/model.onnx +0 -0
  372. onnx/backend/test/data/node/test_lppool_3d_default/test_data_set_0/output_0.pb +0 -0
  373. onnx/backend/test/data/node/test_lstm_batchwise/model.onnx +0 -0
  374. onnx/backend/test/data/node/test_lstm_defaults/model.onnx +0 -0
  375. onnx/backend/test/data/node/test_lstm_with_initial_bias/model.onnx +0 -0
  376. onnx/backend/test/data/node/test_lstm_with_peepholes/model.onnx +0 -0
  377. onnx/backend/test/data/node/test_maxpool_1d_default/model.onnx +0 -0
  378. onnx/backend/test/data/node/test_maxpool_2d_ceil/model.onnx +0 -0
  379. onnx/backend/test/data/node/test_maxpool_2d_ceil_output_size_reduce_by_one/model.onnx +0 -0
  380. onnx/backend/test/data/node/test_maxpool_2d_default/model.onnx +0 -0
  381. onnx/backend/test/data/node/test_maxpool_2d_dilations/model.onnx +0 -0
  382. onnx/backend/test/data/node/test_maxpool_2d_pads/model.onnx +0 -0
  383. onnx/backend/test/data/node/test_maxpool_2d_precomputed_pads/model.onnx +0 -0
  384. onnx/backend/test/data/node/test_maxpool_2d_precomputed_same_upper/model.onnx +0 -0
  385. onnx/backend/test/data/node/test_maxpool_2d_precomputed_strides/model.onnx +0 -0
  386. onnx/backend/test/data/node/test_maxpool_2d_same_lower/model.onnx +0 -0
  387. onnx/backend/test/data/node/test_maxpool_2d_same_upper/model.onnx +0 -0
  388. onnx/backend/test/data/node/test_maxpool_2d_strides/model.onnx +0 -0
  389. onnx/backend/test/data/node/test_maxpool_2d_uint8/model.onnx +0 -0
  390. onnx/backend/test/data/node/test_maxpool_3d_default/model.onnx +0 -0
  391. onnx/backend/test/data/node/test_maxpool_3d_dilations/model.onnx +0 -0
  392. onnx/backend/test/data/node/test_maxpool_3d_dilations_use_ref_impl/model.onnx +0 -0
  393. onnx/backend/test/data/node/test_maxpool_3d_dilations_use_ref_impl_large/model.onnx +0 -0
  394. onnx/backend/test/data/node/test_maxpool_with_argmax_2d_precomputed_pads/model.onnx +0 -0
  395. onnx/backend/test/data/node/test_maxpool_with_argmax_2d_precomputed_strides/model.onnx +0 -0
  396. onnx/backend/test/data/node/test_maxunpool_export_with_output_shape/model.onnx +0 -0
  397. onnx/backend/test/data/node/test_maxunpool_export_without_output_shape/model.onnx +0 -0
  398. onnx/backend/test/data/node/test_mish/model.onnx +0 -0
  399. onnx/backend/test/data/node/test_mish/test_data_set_0/output_0.pb +0 -0
  400. onnx/backend/test/data/node/test_mish_expanded/model.onnx +0 -0
  401. onnx/backend/test/data/node/test_mish_expanded/test_data_set_0/output_0.pb +0 -0
  402. onnx/backend/test/data/node/test_nllloss_NC/model.onnx +0 -0
  403. onnx/backend/test/data/node/test_nllloss_NC_expanded/model.onnx +0 -0
  404. onnx/backend/test/data/node/test_nllloss_NCd1/model.onnx +0 -0
  405. onnx/backend/test/data/node/test_nllloss_NCd1_expanded/model.onnx +0 -0
  406. onnx/backend/test/data/node/test_nllloss_NCd1_ii/model.onnx +0 -0
  407. onnx/backend/test/data/node/test_nllloss_NCd1_ii_expanded/model.onnx +0 -0
  408. onnx/backend/test/data/node/test_nllloss_NCd1_mean_weight_negative_ii/model.onnx +0 -0
  409. onnx/backend/test/data/node/test_nllloss_NCd1_mean_weight_negative_ii_expanded/model.onnx +0 -0
  410. onnx/backend/test/data/node/test_nllloss_NCd1_weight/model.onnx +0 -0
  411. onnx/backend/test/data/node/test_nllloss_NCd1_weight_expanded/model.onnx +0 -0
  412. onnx/backend/test/data/node/test_nllloss_NCd1_weight_ii/model.onnx +0 -0
  413. onnx/backend/test/data/node/test_nllloss_NCd1_weight_ii_expanded/model.onnx +0 -0
  414. onnx/backend/test/data/node/test_nllloss_NCd1d2/model.onnx +0 -0
  415. onnx/backend/test/data/node/test_nllloss_NCd1d2_expanded/model.onnx +0 -0
  416. onnx/backend/test/data/node/test_nllloss_NCd1d2_no_weight_reduction_mean_ii/model.onnx +0 -0
  417. onnx/backend/test/data/node/test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded/model.onnx +0 -0
  418. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_mean/model.onnx +0 -0
  419. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_mean_expanded/model.onnx +0 -0
  420. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_sum/model.onnx +0 -0
  421. onnx/backend/test/data/node/test_nllloss_NCd1d2_reduction_sum_expanded/model.onnx +0 -0
  422. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight/model.onnx +0 -0
  423. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_expanded/model.onnx +0 -0
  424. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_mean/model.onnx +0 -0
  425. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_mean_expanded/model.onnx +0 -0
  426. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum/model.onnx +0 -0
  427. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum_expanded/model.onnx +0 -0
  428. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum_ii/model.onnx +0 -0
  429. onnx/backend/test/data/node/test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded/model.onnx +0 -0
  430. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_none_no_weight_negative_ii/model.onnx +0 -0
  431. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded/model.onnx +0 -0
  432. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_sum_weight_high_ii/model.onnx +0 -0
  433. onnx/backend/test/data/node/test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded/model.onnx +0 -0
  434. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_mean_weight/model.onnx +0 -0
  435. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_mean_weight_expanded/model.onnx +0 -0
  436. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_none_no_weight/model.onnx +0 -0
  437. onnx/backend/test/data/node/test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded/model.onnx +0 -0
  438. onnx/backend/test/data/node/test_quantizelinear_int4/test_data_set_0/output_0.pb +1 -1
  439. onnx/backend/test/data/node/test_reduce_log_sum_exp_do_not_keepdims_random/test_data_set_0/output_0.pb +1 -1
  440. onnx/backend/test/data/node/test_reduce_log_sum_exp_do_not_keepdims_random_expanded/test_data_set_0/output_0.pb +1 -1
  441. onnx/backend/test/data/node/test_reduce_log_sum_exp_keepdims_random/test_data_set_0/output_0.pb +1 -1
  442. onnx/backend/test/data/node/test_reduce_log_sum_exp_keepdims_random_expanded/test_data_set_0/output_0.pb +1 -1
  443. onnx/backend/test/data/node/test_reduce_log_sum_exp_negative_axes_keepdims_random/test_data_set_0/output_0.pb +1 -1
  444. onnx/backend/test/data/node/test_reduce_log_sum_exp_negative_axes_keepdims_random_expanded/test_data_set_0/output_0.pb +1 -1
  445. onnx/backend/test/data/node/test_reduce_max_empty_set/model.onnx +0 -0
  446. onnx/backend/test/data/node/test_reduce_max_empty_set/test_data_set_0/input_0.pb +0 -0
  447. onnx/backend/test/data/node/test_reduce_max_empty_set/test_data_set_0/input_1.pb +0 -0
  448. onnx/backend/test/data/node/test_reduce_max_empty_set/test_data_set_0/output_0.pb +0 -0
  449. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/model.onnx +0 -0
  450. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/test_data_set_0/input_0.pb +1 -0
  451. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/test_data_set_0/input_1.pb +0 -0
  452. onnx/backend/test/data/node/test_reduce_sum_empty_axes_input_noop/test_data_set_0/output_0.pb +1 -0
  453. onnx/backend/test/data/node/test_reduce_sum_negative_axes_keepdims_random/model.onnx +0 -0
  454. onnx/backend/test/data/node/test_reduce_sum_negative_axes_keepdims_random/test_data_set_0/input_1.pb +0 -0
  455. onnx/backend/test/data/node/test_reduce_sum_negative_axes_keepdims_random/test_data_set_0/output_0.pb +1 -1
  456. onnx/backend/test/data/node/test_resize_tf_crop_and_resize/model.onnx +0 -0
  457. onnx/backend/test/data/node/test_resize_tf_crop_and_resize/test_data_set_0/input_1.pb +0 -0
  458. onnx/backend/test/data/node/test_resize_tf_crop_and_resize/test_data_set_0/output_0.pb +0 -0
  459. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/model.onnx +0 -0
  460. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/input_0.pb +0 -0
  461. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/input_1.pb +0 -0
  462. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/input_2.pb +0 -0
  463. onnx/backend/test/data/node/test_resize_tf_crop_and_resize_extrapolation_value/test_data_set_0/output_0.pb +0 -0
  464. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_larger/model.onnx +0 -0
  465. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_larger/test_data_set_0/output_0.pb +0 -0
  466. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/model.onnx +0 -0
  467. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/test_data_set_0/input_0.pb +0 -0
  468. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/test_data_set_0/input_1.pb +0 -0
  469. onnx/backend/test/data/node/test_resize_upsample_sizes_nearest_not_smaller/test_data_set_0/output_0.pb +0 -0
  470. onnx/backend/test/data/node/test_rnn_seq_length/model.onnx +0 -0
  471. onnx/backend/test/data/node/test_roialign_aligned_false/model.onnx +0 -0
  472. onnx/backend/test/data/node/test_roialign_aligned_true/model.onnx +0 -0
  473. onnx/backend/test/data/node/test_roialign_mode_max/model.onnx +0 -0
  474. onnx/backend/test/data/node/test_round/model.onnx +0 -0
  475. onnx/backend/test/data/node/test_selu/model.onnx +0 -0
  476. onnx/backend/test/data/node/test_selu_default/model.onnx +0 -0
  477. onnx/backend/test/data/node/test_selu_example/model.onnx +0 -0
  478. onnx/backend/test/data/node/test_simple_rnn_batchwise/model.onnx +0 -0
  479. onnx/backend/test/data/node/test_simple_rnn_defaults/model.onnx +0 -0
  480. onnx/backend/test/data/node/test_simple_rnn_with_initial_bias/model.onnx +0 -0
  481. onnx/backend/test/data/node/test_sin/model.onnx +0 -0
  482. onnx/backend/test/data/node/test_sin_example/model.onnx +0 -0
  483. onnx/backend/test/data/node/test_sinh/model.onnx +0 -0
  484. onnx/backend/test/data/node/test_sinh/test_data_set_0/output_0.pb +1 -1
  485. onnx/backend/test/data/node/test_sinh_example/model.onnx +0 -0
  486. onnx/backend/test/data/node/test_softplus/model.onnx +0 -0
  487. onnx/backend/test/data/node/test_softplus_example/model.onnx +0 -0
  488. onnx/backend/test/data/node/test_softsign/model.onnx +0 -0
  489. onnx/backend/test/data/node/test_softsign_example/model.onnx +0 -0
  490. onnx/backend/test/data/node/test_stft_with_window/test_data_set_0/input_2.pb +0 -0
  491. onnx/backend/test/data/node/test_stft_with_window/test_data_set_0/output_0.pb +0 -0
  492. onnx/backend/test/data/node/test_tan/model.onnx +0 -0
  493. onnx/backend/test/data/node/test_tan/test_data_set_0/output_0.pb +1 -1
  494. onnx/backend/test/data/node/test_tan_example/model.onnx +0 -0
  495. onnx/backend/test/data/node/test_thresholdedrelu/model.onnx +0 -0
  496. onnx/backend/test/data/node/test_thresholdedrelu_default/model.onnx +0 -0
  497. onnx/backend/test/data/node/test_thresholdedrelu_example/model.onnx +0 -0
  498. onnx/backend/test/data/node/test_training_dropout/model.onnx +0 -0
  499. onnx/backend/test/data/node/test_training_dropout_default/model.onnx +0 -0
  500. onnx/backend/test/data/node/test_training_dropout_default_mask/model.onnx +0 -0
  501. onnx/backend/test/data/node/test_training_dropout_mask/model.onnx +0 -0
  502. onnx/backend/test/data/node/test_training_dropout_zero_ratio/model.onnx +0 -0
  503. onnx/backend/test/data/node/test_training_dropout_zero_ratio_mask/model.onnx +0 -0
  504. onnx/backend/test/loader/__init__.py +11 -6
  505. onnx/backend/test/report/__init__.py +4 -3
  506. onnx/backend/test/report/base.py +1 -0
  507. onnx/backend/test/report/coverage.py +21 -20
  508. onnx/backend/test/runner/__init__.py +12 -8
  509. onnx/backend/test/runner/item.py +3 -2
  510. onnx/backend/test/stat_coverage.py +6 -5
  511. onnx/bin/checker.py +1 -0
  512. onnx/checker.cc +6 -1
  513. onnx/common/version.h +1 -1
  514. onnx/compose.py +66 -50
  515. onnx/cpp2py_export.cc +4 -0
  516. onnx/defs/__init__.py +2 -2
  517. onnx/defs/data_type_utils.cc +0 -1
  518. onnx/defs/gen_doc.py +9 -8
  519. onnx/defs/gen_shape_inference_information.py +1 -0
  520. onnx/defs/generator/defs.cc +32 -84
  521. onnx/defs/generator/old.cc +389 -0
  522. onnx/defs/math/defs.cc +308 -313
  523. onnx/defs/math/old.cc +989 -7
  524. onnx/defs/math/utils.cc +12 -1
  525. onnx/defs/math/utils.h +2 -0
  526. onnx/defs/nn/defs.cc +57 -75
  527. onnx/defs/nn/old.cc +1536 -2
  528. onnx/defs/object_detection/defs.cc +4 -7
  529. onnx/defs/object_detection/old.cc +117 -0
  530. onnx/defs/operator_sets.h +108 -1
  531. onnx/defs/parser.cc +10 -1
  532. onnx/defs/quantization/defs.cc +3 -2
  533. onnx/defs/quantization/old.cc +4 -1
  534. onnx/defs/rnn/defs.cc +10 -13
  535. onnx/defs/rnn/old.cc +517 -2
  536. onnx/defs/schema.cc +53 -59
  537. onnx/defs/schema.h +58 -2
  538. onnx/defs/shape_inference.h +67 -18
  539. onnx/defs/tensor/defs.cc +22 -20
  540. onnx/defs/tensor/old.cc +111 -0
  541. onnx/external_data_helper.py +27 -14
  542. onnx/gen_proto.py +3 -2
  543. onnx/helper.py +86 -61
  544. onnx/hub.py +30 -28
  545. onnx/inliner/inliner.cc +0 -1
  546. onnx/mapping.py +3 -2
  547. onnx/numpy_helper.py +159 -23
  548. onnx/onnx-ml.proto +1 -1
  549. onnx/onnx.in.proto +1 -1
  550. onnx/onnx.proto +1 -1
  551. onnx/onnx_cpp2py_export/defs.pyi +0 -2
  552. onnx/onnx_cpp2py_export/inliner.pyi +0 -4
  553. onnx/onnx_cpp2py_export/parser.pyi +0 -4
  554. onnx/onnx_cpp2py_export.cp39-win32.pyd +0 -0
  555. onnx/parser.py +1 -0
  556. onnx/printer.py +2 -3
  557. onnx/reference/__init__.py +1 -0
  558. onnx/reference/custom_element_types.py +73 -8
  559. onnx/reference/op_run.py +13 -58
  560. onnx/reference/ops/__init__.py +1 -0
  561. onnx/reference/ops/_helpers.py +6 -4
  562. onnx/reference/ops/_op.py +16 -5
  563. onnx/reference/ops/_op_common_indices.py +1 -1
  564. onnx/reference/ops/_op_common_pool.py +38 -29
  565. onnx/reference/ops/_op_common_random.py +1 -1
  566. onnx/reference/ops/_op_common_window.py +2 -2
  567. onnx/reference/ops/_op_list.py +9 -6
  568. onnx/reference/ops/aionnx_preview_training/__init__.py +1 -0
  569. onnx/reference/ops/aionnx_preview_training/_op_list.py +5 -7
  570. onnx/reference/ops/aionnx_preview_training/_op_run_training.py +1 -1
  571. onnx/reference/ops/aionnx_preview_training/op_adagrad.py +14 -5
  572. onnx/reference/ops/aionnx_preview_training/op_adam.py +2 -2
  573. onnx/reference/ops/aionnx_preview_training/op_momentum.py +14 -2
  574. onnx/reference/ops/aionnxml/__init__.py +1 -0
  575. onnx/reference/ops/aionnxml/_common_classifier.py +1 -0
  576. onnx/reference/ops/aionnxml/_op_list.py +5 -6
  577. onnx/reference/ops/aionnxml/_op_run_aionnxml.py +1 -1
  578. onnx/reference/ops/aionnxml/op_array_feature_extractor.py +1 -1
  579. onnx/reference/ops/aionnxml/op_binarizer.py +1 -1
  580. onnx/reference/ops/aionnxml/op_dict_vectorizer.py +2 -2
  581. onnx/reference/ops/aionnxml/op_feature_vectorizer.py +1 -1
  582. onnx/reference/ops/aionnxml/op_imputer.py +3 -3
  583. onnx/reference/ops/aionnxml/op_label_encoder.py +1 -1
  584. onnx/reference/ops/aionnxml/op_linear_classifier.py +2 -2
  585. onnx/reference/ops/aionnxml/op_linear_regressor.py +1 -1
  586. onnx/reference/ops/aionnxml/op_normalizer.py +1 -1
  587. onnx/reference/ops/aionnxml/op_one_hot_encoder.py +1 -1
  588. onnx/reference/ops/aionnxml/op_scaler.py +1 -1
  589. onnx/reference/ops/aionnxml/op_svm_classifier.py +10 -7
  590. onnx/reference/ops/aionnxml/op_svm_helper.py +2 -2
  591. onnx/reference/ops/aionnxml/op_svm_regressor.py +1 -1
  592. onnx/reference/ops/aionnxml/op_tree_ensemble.py +3 -3
  593. onnx/reference/ops/aionnxml/op_tree_ensemble_classifier.py +1 -1
  594. onnx/reference/ops/aionnxml/op_tree_ensemble_helper.py +2 -2
  595. onnx/reference/ops/aionnxml/op_tree_ensemble_regressor.py +5 -3
  596. onnx/reference/ops/experimental/__init__.py +1 -0
  597. onnx/reference/ops/experimental/_op_list.py +6 -12
  598. onnx/reference/ops/experimental/_op_run_experimental.py +1 -1
  599. onnx/reference/ops/experimental/op_im2col.py +1 -1
  600. onnx/reference/ops/op_abs.py +1 -1
  601. onnx/reference/ops/op_acos.py +1 -1
  602. onnx/reference/ops/op_acosh.py +1 -1
  603. onnx/reference/ops/op_add.py +1 -1
  604. onnx/reference/ops/op_affine_grid.py +1 -1
  605. onnx/reference/ops/op_and.py +1 -1
  606. onnx/reference/ops/op_argmax.py +1 -1
  607. onnx/reference/ops/op_argmin.py +1 -1
  608. onnx/reference/ops/op_asin.py +1 -1
  609. onnx/reference/ops/op_asinh.py +1 -1
  610. onnx/reference/ops/op_atan.py +1 -1
  611. onnx/reference/ops/op_atanh.py +1 -1
  612. onnx/reference/ops/op_attribute_has_value.py +15 -15
  613. onnx/reference/ops/op_average_pool.py +1 -1
  614. onnx/reference/ops/op_batch_normalization.py +13 -2
  615. onnx/reference/ops/op_bernoulli.py +1 -1
  616. onnx/reference/ops/op_bitshift.py +1 -1
  617. onnx/reference/ops/op_bitwise_and.py +1 -1
  618. onnx/reference/ops/op_bitwise_not.py +1 -1
  619. onnx/reference/ops/op_bitwise_or.py +1 -1
  620. onnx/reference/ops/op_bitwise_xor.py +1 -1
  621. onnx/reference/ops/op_blackman_window.py +1 -1
  622. onnx/reference/ops/op_cast.py +11 -10
  623. onnx/reference/ops/op_cast_like.py +1 -1
  624. onnx/reference/ops/op_ceil.py +1 -1
  625. onnx/reference/ops/op_celu.py +1 -1
  626. onnx/reference/ops/op_center_crop_pad.py +1 -1
  627. onnx/reference/ops/op_clip.py +1 -1
  628. onnx/reference/ops/op_col2im.py +10 -4
  629. onnx/reference/ops/op_compress.py +1 -1
  630. onnx/reference/ops/op_concat.py +1 -1
  631. onnx/reference/ops/op_concat_from_sequence.py +3 -3
  632. onnx/reference/ops/op_constant.py +2 -2
  633. onnx/reference/ops/op_constant_of_shape.py +1 -1
  634. onnx/reference/ops/op_conv.py +22 -17
  635. onnx/reference/ops/op_conv_integer.py +1 -1
  636. onnx/reference/ops/op_conv_transpose.py +37 -6
  637. onnx/reference/ops/op_cos.py +1 -1
  638. onnx/reference/ops/op_cosh.py +1 -1
  639. onnx/reference/ops/op_cum_sum.py +1 -1
  640. onnx/reference/ops/op_deform_conv.py +1 -1
  641. onnx/reference/ops/op_depth_to_space.py +1 -1
  642. onnx/reference/ops/op_dequantize_linear.py +7 -9
  643. onnx/reference/ops/op_det.py +1 -1
  644. onnx/reference/ops/op_dft.py +16 -2
  645. onnx/reference/ops/op_div.py +1 -1
  646. onnx/reference/ops/op_dropout.py +9 -8
  647. onnx/reference/ops/op_dynamic_quantize_linear.py +1 -1
  648. onnx/reference/ops/op_einsum.py +1 -1
  649. onnx/reference/ops/op_elu.py +1 -1
  650. onnx/reference/ops/op_equal.py +1 -1
  651. onnx/reference/ops/op_erf.py +1 -1
  652. onnx/reference/ops/op_exp.py +1 -1
  653. onnx/reference/ops/op_expand.py +1 -1
  654. onnx/reference/ops/op_eyelike.py +2 -2
  655. onnx/reference/ops/op_flatten.py +1 -1
  656. onnx/reference/ops/op_floor.py +1 -1
  657. onnx/reference/ops/op_gather.py +1 -1
  658. onnx/reference/ops/op_gather_elements.py +3 -3
  659. onnx/reference/ops/op_gathernd.py +2 -4
  660. onnx/reference/ops/op_gemm.py +12 -2
  661. onnx/reference/ops/op_global_average_pool.py +1 -1
  662. onnx/reference/ops/op_global_max_pool.py +1 -1
  663. onnx/reference/ops/op_greater.py +1 -1
  664. onnx/reference/ops/op_greater_or_equal.py +1 -1
  665. onnx/reference/ops/op_grid_sample.py +2 -3
  666. onnx/reference/ops/op_gru.py +7 -7
  667. onnx/reference/ops/op_hamming_window.py +1 -1
  668. onnx/reference/ops/op_hann_window.py +1 -1
  669. onnx/reference/ops/op_hard_sigmoid.py +1 -1
  670. onnx/reference/ops/op_hardmax.py +5 -2
  671. onnx/reference/ops/op_identity.py +3 -3
  672. onnx/reference/ops/op_if.py +2 -2
  673. onnx/reference/ops/op_instance_normalization.py +1 -1
  674. onnx/reference/ops/op_isinf.py +1 -1
  675. onnx/reference/ops/op_isnan.py +1 -1
  676. onnx/reference/ops/op_layer_normalization.py +2 -4
  677. onnx/reference/ops/op_leaky_relu.py +1 -1
  678. onnx/reference/ops/op_less.py +1 -1
  679. onnx/reference/ops/op_less_or_equal.py +1 -1
  680. onnx/reference/ops/op_log.py +1 -1
  681. onnx/reference/ops/op_log_softmax.py +1 -1
  682. onnx/reference/ops/op_loop.py +4 -2
  683. onnx/reference/ops/op_lp_normalization.py +1 -1
  684. onnx/reference/ops/op_lp_pool.py +4 -2
  685. onnx/reference/ops/op_lrn.py +1 -1
  686. onnx/reference/ops/op_lstm.py +9 -11
  687. onnx/reference/ops/op_matmul.py +1 -1
  688. onnx/reference/ops/op_matmul_integer.py +1 -1
  689. onnx/reference/ops/op_max.py +1 -1
  690. onnx/reference/ops/op_max_pool.py +8 -8
  691. onnx/reference/ops/op_max_unpool.py +5 -3
  692. onnx/reference/ops/op_mean.py +1 -1
  693. onnx/reference/ops/op_mel_weight_matrix.py +1 -1
  694. onnx/reference/ops/op_min.py +1 -1
  695. onnx/reference/ops/op_mod.py +1 -1
  696. onnx/reference/ops/op_mul.py +1 -1
  697. onnx/reference/ops/op_neg.py +1 -1
  698. onnx/reference/ops/op_negative_log_likelihood_loss.py +4 -2
  699. onnx/reference/ops/op_non_max_suppression.py +10 -11
  700. onnx/reference/ops/op_non_zero.py +1 -1
  701. onnx/reference/ops/op_not.py +1 -1
  702. onnx/reference/ops/op_one_hot.py +1 -1
  703. onnx/reference/ops/op_optional.py +1 -1
  704. onnx/reference/ops/op_optional_get_element.py +1 -1
  705. onnx/reference/ops/op_optional_has_element.py +1 -1
  706. onnx/reference/ops/op_or.py +1 -1
  707. onnx/reference/ops/op_pad.py +1 -1
  708. onnx/reference/ops/op_pool_common.py +7 -6
  709. onnx/reference/ops/op_pow.py +1 -1
  710. onnx/reference/ops/op_prelu.py +3 -3
  711. onnx/reference/ops/op_qlinear_conv.py +1 -1
  712. onnx/reference/ops/op_qlinear_matmul.py +1 -1
  713. onnx/reference/ops/op_quantize_linear.py +15 -9
  714. onnx/reference/ops/op_random_normal.py +1 -1
  715. onnx/reference/ops/op_random_normal_like.py +1 -1
  716. onnx/reference/ops/op_random_uniform.py +1 -1
  717. onnx/reference/ops/op_random_uniform_like.py +1 -1
  718. onnx/reference/ops/op_range.py +1 -1
  719. onnx/reference/ops/op_reciprocal.py +1 -1
  720. onnx/reference/ops/op_reduce_l1.py +1 -1
  721. onnx/reference/ops/op_reduce_l2.py +1 -1
  722. onnx/reference/ops/op_reduce_log_sum.py +1 -1
  723. onnx/reference/ops/op_reduce_log_sum_exp.py +1 -1
  724. onnx/reference/ops/op_reduce_max.py +1 -1
  725. onnx/reference/ops/op_reduce_mean.py +2 -2
  726. onnx/reference/ops/op_reduce_min.py +1 -1
  727. onnx/reference/ops/op_reduce_prod.py +1 -1
  728. onnx/reference/ops/op_reduce_sum.py +2 -2
  729. onnx/reference/ops/op_reduce_sum_square.py +1 -1
  730. onnx/reference/ops/op_regex_full_match.py +1 -1
  731. onnx/reference/ops/op_relu.py +1 -1
  732. onnx/reference/ops/op_reshape.py +1 -1
  733. onnx/reference/ops/op_reverse_sequence.py +1 -1
  734. onnx/reference/ops/op_rnn.py +10 -8
  735. onnx/reference/ops/op_roi_align.py +5 -5
  736. onnx/reference/ops/op_round.py +1 -1
  737. onnx/reference/ops/op_scan.py +8 -8
  738. onnx/reference/ops/op_scatter_elements.py +19 -50
  739. onnx/reference/ops/op_scatternd.py +1 -1
  740. onnx/reference/ops/op_selu.py +1 -1
  741. onnx/reference/ops/op_sequence_at.py +1 -1
  742. onnx/reference/ops/op_sequence_construct.py +1 -1
  743. onnx/reference/ops/op_sequence_empty.py +2 -2
  744. onnx/reference/ops/op_sequence_erase.py +1 -1
  745. onnx/reference/ops/op_sequence_insert.py +6 -6
  746. onnx/reference/ops/op_sequence_length.py +1 -1
  747. onnx/reference/ops/op_sequence_map.py +1 -1
  748. onnx/reference/ops/op_shape.py +2 -6
  749. onnx/reference/ops/op_shrink.py +1 -1
  750. onnx/reference/ops/op_sigmoid.py +1 -1
  751. onnx/reference/ops/op_sign.py +1 -1
  752. onnx/reference/ops/op_sin.py +1 -1
  753. onnx/reference/ops/op_sinh.py +1 -1
  754. onnx/reference/ops/op_size.py +1 -1
  755. onnx/reference/ops/op_slice.py +3 -5
  756. onnx/reference/ops/op_softmax.py +1 -1
  757. onnx/reference/ops/op_softmax_cross_entropy_loss.py +1 -1
  758. onnx/reference/ops/op_softplus.py +1 -1
  759. onnx/reference/ops/op_softsign.py +1 -1
  760. onnx/reference/ops/op_space_to_depth.py +1 -1
  761. onnx/reference/ops/op_split.py +1 -1
  762. onnx/reference/ops/op_split_to_sequence.py +5 -7
  763. onnx/reference/ops/op_sqrt.py +1 -1
  764. onnx/reference/ops/op_squeeze.py +1 -1
  765. onnx/reference/ops/op_stft.py +3 -2
  766. onnx/reference/ops/op_string_concat.py +1 -1
  767. onnx/reference/ops/op_string_normalizer.py +8 -8
  768. onnx/reference/ops/op_string_split.py +2 -4
  769. onnx/reference/ops/op_sub.py +1 -1
  770. onnx/reference/ops/op_sum.py +1 -1
  771. onnx/reference/ops/op_tan.py +1 -1
  772. onnx/reference/ops/op_tanh.py +1 -1
  773. onnx/reference/ops/op_tfidf_vectorizer.py +11 -12
  774. onnx/reference/ops/op_thresholded_relu.py +1 -1
  775. onnx/reference/ops/op_tile.py +1 -1
  776. onnx/reference/ops/op_topk.py +7 -2
  777. onnx/reference/ops/op_transpose.py +1 -1
  778. onnx/reference/ops/op_trilu.py +1 -1
  779. onnx/reference/ops/op_unique.py +3 -1
  780. onnx/reference/ops/op_unsqueeze.py +2 -2
  781. onnx/reference/ops/op_upsample.py +1 -1
  782. onnx/reference/ops/op_where.py +1 -1
  783. onnx/reference/ops/op_xor.py +1 -1
  784. onnx/reference/ops_optimized/__init__.py +1 -0
  785. onnx/reference/ops_optimized/op_conv_optimized.py +1 -1
  786. onnx/reference/reference_evaluator.py +27 -13
  787. onnx/serialization.py +1 -1
  788. onnx/shape_inference/implementation.cc +15 -1
  789. onnx/shape_inference/implementation.h +15 -1
  790. onnx/shape_inference.py +1 -1
  791. onnx/subbyte.py +6 -6
  792. onnx/test/basic_test.py +1 -0
  793. onnx/test/checker_test.py +37 -2
  794. onnx/test/compose_test.py +12 -11
  795. onnx/test/cpp/schema_registration_test.cc +3 -3
  796. onnx/test/cpp/shape_inference_test.cc +38 -2
  797. onnx/test/elu_test.py +2 -0
  798. onnx/test/function_inference_test.py +2 -0
  799. onnx/test/function_test.py +1 -0
  800. onnx/test/helper_test.py +77 -16
  801. onnx/test/hub_test.py +1 -1
  802. onnx/test/inference_function_test.py +25 -8
  803. onnx/test/inliner_test.py +2 -0
  804. onnx/test/model_container_refeval_test.py +2 -1
  805. onnx/test/model_container_test.py +1 -0
  806. onnx/test/model_inference_test.py +2 -0
  807. onnx/test/numpy_helper_test.py +56 -1
  808. onnx/test/parser_test.py +48 -2
  809. onnx/test/printer_test.py +2 -0
  810. onnx/test/reference_evaluator_ml_test.py +2 -3
  811. onnx/test/reference_evaluator_model_test.py +2 -0
  812. onnx/test/reference_evaluator_test.py +173 -19
  813. onnx/test/relu_test.py +2 -0
  814. onnx/test/schema_test.py +4 -2
  815. onnx/test/serialization_test.py +2 -0
  816. onnx/test/shape_inference_test.py +349 -19
  817. onnx/test/symbolic_shape_test.py +3 -3
  818. onnx/test/test_backend_onnxruntime.py +272 -1
  819. onnx/test/test_backend_reference.py +24 -3
  820. onnx/test/test_backend_test.py +6 -5
  821. onnx/test/test_external_data.py +91 -2
  822. onnx/test/test_with_ort.py +1 -0
  823. onnx/test/tools_test.py +15 -14
  824. onnx/test/training_tool_test.py +1 -0
  825. onnx/test/utils_test.py +1 -0
  826. onnx/test/version_converter/automatic_downgrade_test.py +2 -0
  827. onnx/test/version_converter/automatic_upgrade_test.py +2 -0
  828. onnx/test/version_converter_test.py +26 -7
  829. onnx/test/version_utils.py +8 -0
  830. onnx/tools/net_drawer.py +6 -5
  831. onnx/tools/replace_constants.py +11 -11
  832. onnx/tools/update_model_dims.py +7 -6
  833. onnx/utils.py +41 -21
  834. onnx/version.py +2 -2
  835. onnx/version_converter/adapters/split_17_18.h +1 -1
  836. onnx/version_converter/convert.h +107 -2
  837. onnx/version_converter.py +3 -2
  838. {onnx-1.16.2.dist-info → onnx-1.17.0.dist-info}/METADATA +9 -12
  839. {onnx-1.16.2.dist-info → onnx-1.17.0.dist-info}/RECORD +843 -817
  840. {onnx-1.16.2.dist-info → onnx-1.17.0.dist-info}/WHEEL +1 -1
  841. {onnx-1.16.2.dist-info → onnx-1.17.0.dist-info}/LICENSE +0 -0
  842. {onnx-1.16.2.dist-info → onnx-1.17.0.dist-info}/entry_points.txt +0 -0
  843. {onnx-1.16.2.dist-info → onnx-1.17.0.dist-info}/top_level.txt +0 -0
onnx/defs/schema.h CHANGED
@@ -15,6 +15,7 @@
15
15
  #include <ostream>
16
16
  #include <set>
17
17
  #include <string>
18
+ #include <string_view>
18
19
  #include <tuple>
19
20
  #include <unordered_map>
20
21
  #include <unordered_set>
@@ -763,12 +764,36 @@ class OpSchema final {
763
764
  return all_tensor_types_ir4;
764
765
  }
765
766
 
767
+ static const std::vector<std::string>& all_non_complex_numeric_types_plus_bool_ir4() {
768
+ static const std::vector<std::string> all_non_complex_numeric_types_plus_bool_ir4 = {
769
+ "tensor(uint8)",
770
+ "tensor(uint16)",
771
+ "tensor(uint32)",
772
+ "tensor(uint64)",
773
+ "tensor(int8)",
774
+ "tensor(int16)",
775
+ "tensor(int32)",
776
+ "tensor(int64)",
777
+ "tensor(bfloat16)",
778
+ "tensor(float16)",
779
+ "tensor(float)",
780
+ "tensor(double)",
781
+ "tensor(bool)"};
782
+ return all_non_complex_numeric_types_plus_bool_ir4;
783
+ }
784
+
766
785
  static const std::vector<std::string>& all_float_types_ir4() {
767
786
  static const std::vector<std::string> all_float_types_ir4 = {
768
787
  "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"};
769
788
  return all_float_types_ir4;
770
789
  }
771
790
 
791
+ static const std::vector<std::string>& all_float_types_plus_Xint8_ir4() {
792
+ static const std::vector<std::string> all_float_types_ir4 = {
793
+ "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int8)", "tensor(uint8)"};
794
+ return all_float_types_ir4;
795
+ }
796
+
772
797
  static const std::vector<std::string>& all_float_types_ir9() {
773
798
  static const std::vector<std::string> all_float_types_ir9 = {
774
799
  "tensor(bfloat16)",
@@ -809,6 +834,16 @@ class OpSchema final {
809
834
  return all_tensor_types_ir10;
810
835
  }
811
836
 
837
+ static const std::vector<std::string>& all_non_complex_tensor_types_ir10() {
838
+ static const std::vector<std::string> all_non_complex_tensor_types_ir10 = {
839
+ "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
840
+ "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
841
+ "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)",
842
+ "tensor(string)", "tensor(bool)", "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
843
+ "tensor(float8e5m2)", "tensor(float8e5m2fnuz)", "tensor(uint4)", "tensor(int4)"};
844
+ return all_non_complex_tensor_types_ir10;
845
+ }
846
+
812
847
  static const std::vector<std::string>& all_tensor_sequence_types() {
813
848
  static const std::vector<std::string> all_tensor_sequence_types = {
814
849
  "seq(tensor(uint8))",
@@ -1098,6 +1133,27 @@ class OpSchema final {
1098
1133
  std::set<std::string>* updated_ops = nullptr) const;
1099
1134
  void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;
1100
1135
 
1136
+ /**
1137
+ * @brief A common function to generate a prefix string for use in fail_check during the verify function.
1138
+ * @param node_name If empty, the returned string will not include the node name.
1139
+ * @return std::string The prefix string.
1140
+ */
1141
+ std::string VerifyFailPrefix(std::string_view node_name) const;
1142
+
1143
+ /**
1144
+ * @brief Verifies if the input number matches the pattern specified in the schema.
1145
+ * @param input_num The number of inputs to be verified against the schema.
1146
+ * @param node_info The prefix string used if the check fails.
1147
+ */
1148
+ void VerifyInputNum(int input_num, std::string_view node_name = "") const;
1149
+
1150
+ /**
1151
+ * @brief Verifies if the output number matches the pattern specified in the schema.
1152
+ * @param output_num The number of outputs to be verified against the schema.
1153
+ * @param node_info The prefix string used if the check fails.
1154
+ */
1155
+ void VerifyOutputNum(int output_num, std::string_view node_name = "") const;
1156
+
1101
1157
  std::string name_;
1102
1158
  std::string file_;
1103
1159
  std::string doc_;
@@ -1153,7 +1209,7 @@ class OpSchemaRegistry final : public ISchemaRegistry {
1153
1209
  // Increase the highest version when you make BC-breaking changes to the
1154
1210
  // operator schema on specific domain. Update the lowest version when it's
1155
1211
  // determined to remove too old version history.
1156
- map_[ONNX_DOMAIN] = std::make_pair(1, 21);
1212
+ map_[ONNX_DOMAIN] = std::make_pair(1, 22);
1157
1213
  map_[AI_ONNX_ML_DOMAIN] = std::make_pair(1, 5);
1158
1214
  map_[AI_ONNX_TRAINING_DOMAIN] = std::make_pair(1, 1);
1159
1215
  // ONNX's preview domain contains operators subject to change, so
@@ -1163,7 +1219,7 @@ class OpSchemaRegistry final : public ISchemaRegistry {
1163
1219
  // Version corresponding last release of ONNX. Update this to match with
1164
1220
  // the max version above in a *release* version of ONNX. But in other
1165
1221
  // versions, the max version may be ahead of the last-release-version.
1166
- last_release_version_map_[ONNX_DOMAIN] = 21;
1222
+ last_release_version_map_[ONNX_DOMAIN] = 22;
1167
1223
  last_release_version_map_[AI_ONNX_ML_DOMAIN] = 5;
1168
1224
  last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
1169
1225
  last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
@@ -105,6 +105,10 @@ struct InferenceContext {
105
105
  virtual const SparseTensorProto* getInputSparseData(size_t index) const = 0;
106
106
  // Gets the shape inputs computed by partial data propagation.
107
107
  virtual const TensorShapeProto* getSymbolicInput(size_t index) const = 0;
108
+ // To display a name the user can use to narrow its search.
109
+ virtual std::string getDisplayName() const {
110
+ return "";
111
+ }
108
112
  };
109
113
 
110
114
  // We use data propagation to perform partial evaluation of the model, to compute statically
@@ -263,7 +267,15 @@ inline void propagateElemTypeFromDtypeToOutput(
263
267
  } else {
264
268
  // This is not expected to happen
265
269
  fail_type_inference(
266
- "Output ", outputIndex, " expected to have: ", expected_value_case, " or UNDEFINED. Got: ", output_value_case);
270
+ "Output ",
271
+ outputIndex,
272
+ " expected to have: ",
273
+ expected_value_case,
274
+ " or UNDEFINED. Got: ",
275
+ output_value_case,
276
+ " in ",
277
+ ctx.getDisplayName(),
278
+ ".");
267
279
  }
268
280
  }
269
281
 
@@ -277,18 +289,18 @@ inline void propagateElemTypeFromDtypeToOutput(InferenceContext& ctx, const Attr
277
289
  const auto attr_type = attr->type();
278
290
  if (attr_type == AttributeProto::TENSOR) {
279
291
  if (attr->t().dims().size() != 1) {
280
- fail_type_inference("Attribute expected to have a one-dim tensor");
292
+ fail_type_inference("Attribute expected to have a one-dim tensor in ", ctx.getDisplayName(), ".");
281
293
  }
282
294
  data_type = attr->t().data_type();
283
295
  expected_value_case = TypeProto::kTensorType;
284
296
  } else if (attr_type == AttributeProto::SPARSE_TENSOR) {
285
297
  if (attr->sparse_tensor().dims().size() != 1) {
286
- fail_type_inference("Attribute expected to have a one-dim sparse tensor");
298
+ fail_type_inference("Attribute expected to have a one-dim sparse tensor in ", ctx.getDisplayName(), ".");
287
299
  }
288
300
  data_type = attr->sparse_tensor().values().data_type();
289
301
  expected_value_case = TypeProto::kSparseTensorType;
290
302
  } else {
291
- fail_type_inference("Attribute expected to have tensor or sparse tensor type");
303
+ fail_type_inference("Attribute expected to have tensor or sparse tensor type in ", ctx.getDisplayName(), ".");
292
304
  }
293
305
 
294
306
  propagateElemTypeFromDtypeToOutput(ctx, data_type, outputIndex, expected_value_case);
@@ -326,7 +338,10 @@ inline const TensorShapeProto& getInputShape(const InferenceContext& ctx, size_t
326
338
  const auto* input_type = ctx.getInputType(n);
327
339
  const auto value_case = input_type->value_case();
328
340
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
329
- fail_type_inference("Attribute expected to have tensor or sparse tensor type");
341
+ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
342
+ }
343
+ if (!hasShape(*input_type)) {
344
+ fail_shape_inference("Input ", n, " must have a non null shape in ", ctx.getDisplayName(), ".");
330
345
  }
331
346
  if (value_case == TypeProto::kTensorType) {
332
347
  return input_type->tensor_type().shape();
@@ -344,7 +359,7 @@ inline const TensorShapeProto* getOptionalInputShape(InferenceContext& ctx, size
344
359
 
345
360
  const auto value_case = input_type->value_case();
346
361
  if (value_case != TypeProto::kTensorType && value_case != TypeProto::kSparseTensorType) {
347
- fail_type_inference("Attribute expected to have tensor or sparse tensor type");
362
+ fail_type_inference("Input ", n, "expected to be a tensor or a sparse tensor type in ", ctx.getDisplayName(), ".");
348
363
  }
349
364
  if (value_case == TypeProto::kTensorType) {
350
365
  return &input_type->tensor_type().shape();
@@ -372,7 +387,10 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType(
372
387
  " does not match type of output: ",
373
388
  outputIndex,
374
389
  "type: ",
375
- output_value_case);
390
+ output_value_case,
391
+ " in ",
392
+ ctx.getDisplayName(),
393
+ ".");
376
394
  }
377
395
  if (TypeProto::kTensorType == input_value_case) {
378
396
  auto* dim = output_type->mutable_tensor_type()->mutable_shape()->add_dim();
@@ -382,7 +400,13 @@ inline void appendSingleDimCopiedFromInputTypeToOutputType(
382
400
  *dim = input_type->sparse_tensor_type().shape().dim(static_cast<int>(fromDimIndex));
383
401
  } else {
384
402
  fail_type_inference(
385
- "Input ", inputIndex, " and Output ", outputIndex, " expected to have tensor or sparse tensor type");
403
+ "Input ",
404
+ inputIndex,
405
+ " and Output ",
406
+ outputIndex,
407
+ " expected to have tensor or sparse tensor type in ",
408
+ ctx.getDisplayName(),
409
+ ".");
386
410
  }
387
411
  }
388
412
 
@@ -440,7 +464,14 @@ updateOutputElemType(InferenceContext& ctx, size_t outputIndex, int32_t elemType
440
464
  setTensorElementType(elemType, expected_type, *output_type);
441
465
  } else {
442
466
  // This is not expected to happen
443
- fail_type_inference("Output ", outputIndex, " expected to have tensor or sparse tensor type: ", expected_type);
467
+ fail_type_inference(
468
+ "Output ",
469
+ outputIndex,
470
+ " expected to have tensor or sparse tensor type: ",
471
+ expected_type,
472
+ " in ",
473
+ ctx.getDisplayName(),
474
+ ".");
444
475
  }
445
476
  }
446
477
 
@@ -462,16 +493,17 @@ inline void propagateElemTypeFromAttributeToOutput(
462
493
  updateOutputElemType(ctx, outputIndex, default_value, expected_type);
463
494
  return;
464
495
  } else {
465
- fail_type_inference("Value of attribute ", attributeName, " not specified");
496
+ fail_type_inference("Value of attribute ", attributeName, " not specified in ", ctx.getDisplayName(), ".");
466
497
  }
467
498
  }
468
499
  if (!attr_proto->has_i()) {
469
- fail_type_inference("Attribute ", attributeName, " should be of integer type and specify a type.");
500
+ fail_type_inference(
501
+ "Attribute ", attributeName, " should be of integer type and specify a type in ", ctx.getDisplayName(), ".");
470
502
  }
471
503
  auto attr_value = attr_proto->i();
472
504
  auto elem_type = static_cast<TensorProto_DataType>(attr_value);
473
505
  if (!TensorProto_DataType_IsValid(elem_type)) {
474
- fail_type_inference("Attribute ", attributeName, " does not specify a valid type.");
506
+ fail_type_inference("Attribute ", attributeName, " does not specify a valid type in ", ctx.getDisplayName(), ".");
475
507
  }
476
508
  updateOutputElemType(ctx, outputIndex, elem_type, expected_type);
477
509
  }
@@ -497,7 +529,7 @@ inline TensorShapeProto*
497
529
  getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_type = TypeProto::kTensorType) {
498
530
  auto output_type = ctx.getOutputType(n);
499
531
  if (output_type == nullptr) {
500
- fail_type_inference("Output ", n, " expected to have tensor or sparse type");
532
+ fail_type_inference("Output ", n, " expected to have tensor or sparse type in ", ctx.getDisplayName(), ".");
501
533
  }
502
534
  const auto output_value_case = output_type->value_case();
503
535
  if (output_value_case == TypeProto::kTensorType || output_value_case == TypeProto::kSparseTensorType) {
@@ -505,7 +537,7 @@ getOutputShape(InferenceContext& ctx, size_t n, TypeProto::ValueCase default_typ
505
537
  } else if (output_value_case == TypeProto::VALUE_NOT_SET) {
506
538
  return getTensorMutableShape(default_type, *output_type);
507
539
  } else {
508
- fail_type_inference("Output ", n, " expected to have tensor type");
540
+ fail_type_inference("Output ", n, " expected to have tensor type in ", ctx.getDisplayName(), ".");
509
541
  }
510
542
  }
511
543
 
@@ -562,13 +594,13 @@ inline void propagateShapeFromAttributeToOutput(
562
594
  auto attr_proto = ctx.getAttribute(attributeName);
563
595
  if ((nullptr == attr_proto) || (!attr_proto->has_type()) ||
564
596
  (attr_proto->type() != AttributeProto_AttributeType_INTS)) {
565
- fail_shape_inference("Attribute ", attributeName, " should specify a shape");
597
+ fail_shape_inference("Attribute ", attributeName, " should specify a shape in ", ctx.getDisplayName(), ".");
566
598
  }
567
599
  auto& int_list = attr_proto->ints();
568
600
  TensorShapeProto shape;
569
601
  for (auto dim_size : int_list) {
570
602
  if (dim_size < 0) {
571
- fail_shape_inference("Negative values are not allowed in a shape specification");
603
+ fail_shape_inference("Negative values are not allowed in a shape specification in ", ctx.getDisplayName(), ".");
572
604
  }
573
605
  shape.add_dim()->set_dim_value(dim_size);
574
606
  }
@@ -745,7 +777,16 @@ inline void checkInputRank(InferenceContext& ctx, size_t input_index, int expect
745
777
  if (hasInputShape(ctx, input_index)) {
746
778
  auto rank = getInputShape(ctx, input_index).dim_size();
747
779
  if (rank != expected_rank) {
748
- fail_shape_inference("Input ", input_index, " expected to have rank ", expected_rank, " but has rank ", rank);
780
+ fail_shape_inference(
781
+ "Input ",
782
+ input_index,
783
+ " expected to have rank ",
784
+ expected_rank,
785
+ " but has rank ",
786
+ rank,
787
+ " in ",
788
+ ctx.getDisplayName(),
789
+ ".");
749
790
  }
750
791
  }
751
792
  }
@@ -798,7 +839,15 @@ inline void unifyInputDim(InferenceContext& ctx, size_t input_index, int dim_ind
798
839
  // This shape is expected to have rank > dim_index:
799
840
  if (input_shape.dim_size() <= dim_index) {
800
841
  fail_shape_inference(
801
- "Input ", input_index, " expected to have rank >", dim_index, " but has rank ", input_shape.dim_size());
842
+ "Input ",
843
+ input_index,
844
+ " expected to have rank >",
845
+ dim_index,
846
+ " but has rank ",
847
+ input_shape.dim_size(),
848
+ " in ",
849
+ ctx.getDisplayName(),
850
+ ".");
802
851
  }
803
852
  const Dim& input_dim = input_shape.dim(dim_index);
804
853
  // Now, unify dim and input_dim:
onnx/defs/tensor/defs.cc CHANGED
@@ -5,6 +5,7 @@
5
5
  #include <algorithm>
6
6
  #include <cmath>
7
7
  #include <numeric>
8
+ #include <optional>
8
9
 
9
10
  #include "onnx/defs/data_propagators.h"
10
11
  #include "onnx/defs/function.h"
@@ -135,7 +136,7 @@ ONNX_OPERATOR_SET_SCHEMA(
135
136
  PropagateShapeDataFromInputToOutput(ctx, 0);
136
137
  }));
137
138
 
138
- static const char* CastLike_ver19_doc = R"DOC(
139
+ static const char* CastLike_ver21_doc = R"DOC(
139
140
  The operator casts the elements of a given input tensor (the first input) to
140
141
  the same data type as the elements of the second input tensor.
141
142
  See documentation of the Cast operator for further details.
@@ -145,7 +146,7 @@ ONNX_OPERATOR_SET_SCHEMA(
145
146
  CastLike,
146
147
  21,
147
148
  OpSchema()
148
- .SetDoc(CastLike_ver19_doc)
149
+ .SetDoc(CastLike_ver21_doc)
149
150
  .Attr(
150
151
  "saturate",
151
152
  "The parameter defines how the conversion behaves if an input value is out of "
@@ -175,19 +176,11 @@ ONNX_OPERATOR_SET_SCHEMA(
175
176
  OpSchema::Differentiable)
176
177
  .TypeConstraint(
177
178
  "T1",
178
- {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int8)",
179
- "tensor(int16)", "tensor(int32)", "tensor(int64)", "tensor(uint8)",
180
- "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(bool)",
181
- "tensor(string)", "tensor(bfloat16)", "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
182
- "tensor(float8e5m2)", "tensor(float8e5m2fnuz)", "tensor(uint4)", "tensor(int4)"},
179
+ OpSchema::all_non_complex_tensor_types_ir10(),
183
180
  "Constrain input types. Casting from complex is not supported.")
184
181
  .TypeConstraint(
185
182
  "T2",
186
- {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(int8)",
187
- "tensor(int16)", "tensor(int32)", "tensor(int64)", "tensor(uint8)",
188
- "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(bool)",
189
- "tensor(string)", "tensor(bfloat16)", "tensor(float8e4m3fn)", "tensor(float8e4m3fnuz)",
190
- "tensor(float8e5m2)", "tensor(float8e5m2fnuz)", "tensor(uint4)", "tensor(int4)"},
183
+ OpSchema::all_non_complex_tensor_types_ir10(),
191
184
  "Constrain output types. Casting to complex is not supported.")
192
185
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
193
186
  propagateElemTypeFromInputToOutput(ctx, 1, 0);
@@ -2323,7 +2316,7 @@ ONNX_OPERATOR_SET_SCHEMA(
2323
2316
  .SetDoc(Resize_ver19_doc)
2324
2317
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { resizeShapeInference_opset18_to_19(ctx); }));
2325
2318
 
2326
- static const char* GridSample_ver20_doc = R"DOC(
2319
+ static const char* GridSample_ver22_doc = R"DOC(
2327
2320
  Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
2328
2321
  For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
2329
2322
  the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
@@ -2346,7 +2339,7 @@ See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/ge
2346
2339
 
2347
2340
  ONNX_OPERATOR_SET_SCHEMA(
2348
2341
  GridSample,
2349
- 20,
2342
+ 22,
2350
2343
  OpSchema()
2351
2344
  .Attr(
2352
2345
  "mode",
@@ -2412,13 +2405,10 @@ ONNX_OPERATOR_SET_SCHEMA(
2412
2405
  OpSchema::Differentiable)
2413
2406
  .TypeConstraint(
2414
2407
  "T1",
2415
- OpSchema::all_tensor_types(),
2408
+ OpSchema::all_tensor_types_ir4(),
2416
2409
  "Constrain input `X` and output `Y` types to all tensor types.")
2417
- .TypeConstraint(
2418
- "T2",
2419
- {"tensor(float16)", "tensor(float)", "tensor(double)"},
2420
- "Constrain grid types to float tensors.")
2421
- .SetDoc(GridSample_ver20_doc)
2410
+ .TypeConstraint("T2", OpSchema::all_float_types_ir4(), "Constrain grid types to float tensors.")
2411
+ .SetDoc(GridSample_ver22_doc)
2422
2412
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { gridSampleShapeInference(ctx); }));
2423
2413
 
2424
2414
  static const char* AffineGrid_ver20_doc = R"DOC(
@@ -2895,8 +2885,18 @@ ONNX_OPERATOR_SET_SCHEMA(
2895
2885
  // and 1 element vector for now. In future when version update for
2896
2886
  // this op is done we should only allow scalar or change the spec to
2897
2887
  // allow both.
2888
+ std::optional<int64_t> depth_value;
2898
2889
  if (hasInputShape(ctx, 1)) {
2899
2890
  auto& depth_shape = getInputShape(ctx, 1);
2891
+ if (const TensorProto* depth_data = ctx.getInputData(1)) {
2892
+ if (depth_data->data_type() == TensorProto::INT64) {
2893
+ depth_value = ParseData<int64_t>(depth_data)[0];
2894
+ } else if (depth_data->data_type() == TensorProto::INT32) {
2895
+ depth_value = ParseData<int32_t>(depth_data)[0];
2896
+ } else if (depth_data->data_type() == TensorProto::FLOAT) {
2897
+ depth_value = static_cast<int64_t>(ParseData<float>(depth_data)[0]);
2898
+ }
2899
+ }
2900
2900
  if (depth_shape.dim_size() != 0 && depth_shape.dim_size() != 1) {
2901
2901
  fail_type_inference("Input 'depth' must be a scalar or rank 1 tensor.");
2902
2902
  }
@@ -2947,6 +2947,8 @@ ONNX_OPERATOR_SET_SCHEMA(
2947
2947
  } else if (indices_shape.dim(i - 1).has_dim_param()) {
2948
2948
  dim->set_dim_param(indices_shape.dim(i - 1).dim_param());
2949
2949
  }
2950
+ } else if (depth_value) {
2951
+ dim->set_dim_value(*depth_value);
2950
2952
  }
2951
2953
  }
2952
2954
  }
onnx/defs/tensor/old.cc CHANGED
@@ -5,6 +5,7 @@
5
5
  #include <algorithm>
6
6
  #include <cmath>
7
7
  #include <numeric>
8
+ #include <optional>
8
9
 
9
10
  #include "onnx/defs/data_propagators.h"
10
11
  #include "onnx/defs/function.h"
@@ -12,6 +13,104 @@
12
13
 
13
14
  namespace ONNX_NAMESPACE {
14
15
 
16
+ static const char* GridSample_ver20_doc = R"DOC(
17
+ Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
18
+ For spatial input `X` with shape (N, C, H, W), the `grid` will have shape (N, H_out, W_out, 2),
19
+ the output `Y` will have shape (N, C, H_out, W_out). For volumetric input `X` with shape (N, C, D, H, W),
20
+ the `grid` will have shape (N, D_out, H_out, W_out, 3), the output `Y` will have shape (N, C, D_out, H_out, W_out).
21
+ More generally, for an input `X` of rank r+2 with shape (N, C, d1, d2, ..., dr),
22
+ the `grid` will have shape (N, D1_out, D2_out, ..., Dr_out, r), the output `Y` will have shape (N, C, D1_out, D2_out, ..., Dr_out).
23
+
24
+ The tensor `X` contains values at centers of square pixels (voxels, etc) locations such as (n, c, d1_in, d2_in, ..., dr_in).
25
+ The (n, d1_out, d2_out, ..., dr_out, :) values from the tensor `grid` are the normalized positions for interpolating the values
26
+ at the (n, c, d1_out, d2_out, ..., dr_out) locations from the output tensor `Y` using a specified interpolation method (the mode)
27
+ and a padding mode (for `grid` positions falling outside the 2-dimensional image).
28
+
29
+ For example, the values in `grid[n, h_out, w_out, :]` are size-2 vectors specifying normalized positions in the 2-dimensional space of `X`.
30
+ They are used to interpolate output values of `Y[n, c, h_out, w_out]`.
31
+
32
+ The GridSample operator is often used in doing grid generator and sampler in the
33
+ [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025).
34
+ See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html).
35
+ )DOC";
36
+
37
+ ONNX_OPERATOR_SET_SCHEMA(
38
+ GridSample,
39
+ 20,
40
+ OpSchema()
41
+ .Attr(
42
+ "mode",
43
+ "Three interpolation modes: linear (default), nearest and cubic. "
44
+ "The \"linear\" mode includes linear and N-linear interpolation modes depending on the number of spatial dimensions "
45
+ "of the input tensor (i.e. linear for 1 spatial dimension, bilinear for 2 spatial dimensions, etc.). "
46
+ "The \"cubic\" mode also includes N-cubic interpolation modes following the same rules. The \"nearest\" mode rounds "
47
+ "to the nearest even index when the sampling point falls halfway between two indices.",
48
+ AttributeProto::STRING,
49
+ std::string("linear"))
50
+ .Attr(
51
+ "padding_mode",
52
+ "Support padding modes for outside grid values: `zeros`(default), `border`, `reflection`. "
53
+ "zeros: use 0 for out-of-bound grid locations, "
54
+ "border: use border values for out-of-bound grid locations, "
55
+ "reflection: use values at locations reflected by the border for out-of-bound grid locations. "
56
+ "If index 0 represents the margin pixel, the reflected value at index -1 will be the same as the value at index 1. "
57
+ "For location far away from the border, it will keep being reflected until becoming in bound. "
58
+ "If pixel location x = -3.5 reflects by border -1 and becomes x' = 1.5, then reflects by border 1 and becomes x'' = 0.5.",
59
+ AttributeProto::STRING,
60
+ std::string("zeros"))
61
+ .Attr(
62
+ "align_corners",
63
+ "If align_corners=1, the extrema (-1 and 1) are considered as referring to the center points of the input's corner pixels (voxels, etc.). "
64
+ "If align_corners=0, they are instead considered as referring to the corner points of the input's corner pixels (voxels, etc.), "
65
+ "making the sampling more resolution agnostic.",
66
+ AttributeProto::INT,
67
+ static_cast<int64_t>(0))
68
+ .Input(
69
+ 0,
70
+ "X",
71
+ "Input tensor of rank r+2 that has shape (N, C, D1, D2, ..., Dr), where N is the batch size, "
72
+ "C is the number of channels, D1, D2, ..., Dr are the spatial dimensions.",
73
+ "T1",
74
+ OpSchema::Single,
75
+ true,
76
+ 1,
77
+ OpSchema::Differentiable)
78
+ .Input(
79
+ 1,
80
+ "grid",
81
+ "Input offset of shape (N, D1_out, D2_out, ..., Dr_out, r), where D1_out, D2_out, ..., "
82
+ "Dr_out are the spatial dimensions of the grid and output, and r is the number of spatial dimensions. "
83
+ "Grid specifies the sampling locations normalized by the input spatial dimensions. "
84
+ "Therefore, it should have most values in the range of [-1, 1]. If the grid has values outside the range of [-1, 1], "
85
+ "the corresponding outputs will be handled as defined by padding_mode. Following computer vision convention, "
86
+ "the coordinates in the length-r location vector are listed from the innermost tensor dimension to the outermost, "
87
+ "the opposite of regular tensor indexing.",
88
+ "T2",
89
+ OpSchema::Single,
90
+ true,
91
+ 1,
92
+ OpSchema::NonDifferentiable)
93
+ .Output(
94
+ 0,
95
+ "Y",
96
+ "Output tensor of rank r+2 that has shape (N, C, D1_out, D2_out, ..., Dr_out) of the sampled values. "
97
+ "For integer input types, intermediate values are computed as floating point and cast to integer at the end.",
98
+ "T1",
99
+ OpSchema::Single,
100
+ true,
101
+ 1,
102
+ OpSchema::Differentiable)
103
+ .TypeConstraint(
104
+ "T1",
105
+ OpSchema::all_tensor_types(),
106
+ "Constrain input `X` and output `Y` types to all tensor types.")
107
+ .TypeConstraint(
108
+ "T2",
109
+ {"tensor(float16)", "tensor(float)", "tensor(double)"},
110
+ "Constrain grid types to float tensors.")
111
+ .SetDoc(GridSample_ver20_doc)
112
+ .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { gridSampleShapeInference(ctx); }));
113
+
15
114
  static const char* Cast_ver19_doc = R"DOC(
16
115
  The operator casts the elements of a given input tensor to a data type
17
116
  specified by the 'to' argument and returns an output tensor of the same size in
@@ -5152,8 +5251,18 @@ ONNX_OPERATOR_SET_SCHEMA(
5152
5251
  // and 1 element vector for now. In future when version update for
5153
5252
  // this op is done we should only allow scalar or change the spec to
5154
5253
  // allow both.
5254
+ std::optional<int64_t> depth_value;
5155
5255
  if (hasInputShape(ctx, 1)) {
5156
5256
  auto& depth_shape = getInputShape(ctx, 1);
5257
+ if (const TensorProto* depth_data = ctx.getInputData(1)) {
5258
+ if (depth_data->data_type() == TensorProto::INT64) {
5259
+ depth_value = ParseData<int64_t>(depth_data)[0];
5260
+ } else if (depth_data->data_type() == TensorProto::INT32) {
5261
+ depth_value = ParseData<int32_t>(depth_data)[0];
5262
+ } else if (depth_data->data_type() == TensorProto::FLOAT) {
5263
+ depth_value = static_cast<int64_t>(ParseData<float>(depth_data)[0]);
5264
+ }
5265
+ }
5157
5266
  if (depth_shape.dim_size() != 0 && depth_shape.dim_size() != 1) {
5158
5267
  fail_type_inference("Input 'depth' must be a scalar or rank 1 tensor.");
5159
5268
  }
@@ -5204,6 +5313,8 @@ ONNX_OPERATOR_SET_SCHEMA(
5204
5313
  } else if (indices_shape.dim(i - 1).has_dim_param()) {
5205
5314
  dim->set_dim_param(indices_shape.dim(i - 1).dim_param());
5206
5315
  }
5316
+ } else if (depth_value) {
5317
+ dim->set_dim_value(*depth_value);
5207
5318
  }
5208
5319
  }
5209
5320
  }
@@ -1,15 +1,23 @@
1
1
  # Copyright (c) ONNX Project Contributors
2
2
  #
3
3
  # SPDX-License-Identifier: Apache-2.0
4
+ from __future__ import annotations
5
+
4
6
  import os
5
7
  import re
6
8
  import sys
7
9
  import uuid
8
10
  from itertools import chain
9
- from typing import Callable, Iterable, Optional
11
+ from typing import Callable, Iterable
10
12
 
11
13
  import onnx.onnx_cpp2py_export.checker as c_checker
12
- from onnx.onnx_pb import AttributeProto, GraphProto, ModelProto, TensorProto
14
+ from onnx.onnx_pb import (
15
+ AttributeProto,
16
+ FunctionProto,
17
+ GraphProto,
18
+ ModelProto,
19
+ TensorProto,
20
+ )
13
21
 
14
22
 
15
23
  class ExternalDataInfo:
@@ -71,10 +79,10 @@ def load_external_data_for_model(model: ModelProto, base_dir: str) -> None:
71
79
  def set_external_data(
72
80
  tensor: TensorProto,
73
81
  location: str,
74
- offset: Optional[int] = None,
75
- length: Optional[int] = None,
76
- checksum: Optional[str] = None,
77
- basepath: Optional[str] = None,
82
+ offset: int | None = None,
83
+ length: int | None = None,
84
+ checksum: str | None = None,
85
+ basepath: str | None = None,
78
86
  ) -> None:
79
87
  if not tensor.HasField("raw_data"):
80
88
  raise ValueError(
@@ -101,7 +109,7 @@ def set_external_data(
101
109
  def convert_model_to_external_data(
102
110
  model: ModelProto,
103
111
  all_tensors_to_one_file: bool = True,
104
- location: Optional[str] = None,
112
+ location: str | None = None,
105
113
  size_threshold: int = 1024,
106
114
  convert_attribute: bool = False,
107
115
  ) -> None:
@@ -220,11 +228,12 @@ def _recursive_attribute_processor(
220
228
 
221
229
 
222
230
  def _get_initializer_tensors_from_graph(
223
- onnx_model_proto_graph: GraphProto,
231
+ graph_or_function: GraphProto | FunctionProto, /
224
232
  ) -> Iterable[TensorProto]:
225
- """Create an iterator of initializer tensors from ONNX model graph."""
226
- yield from onnx_model_proto_graph.initializer
227
- for node in onnx_model_proto_graph.node:
233
+ """Create an iterator of initializer tensors from ONNX model graph/function."""
234
+ if isinstance(graph_or_function, GraphProto):
235
+ yield from graph_or_function.initializer
236
+ for node in graph_or_function.node:
228
237
  for attribute in node.attribute:
229
238
  yield from _recursive_attribute_processor(
230
239
  attribute, _get_initializer_tensors_from_graph
@@ -234,13 +243,15 @@ def _get_initializer_tensors_from_graph(
234
243
  def _get_initializer_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
235
244
  """Create an iterator of initializer tensors from ONNX model."""
236
245
  yield from _get_initializer_tensors_from_graph(onnx_model_proto.graph)
246
+ for function in onnx_model_proto.functions:
247
+ yield from _get_attribute_tensors_from_graph(function)
237
248
 
238
249
 
239
250
  def _get_attribute_tensors_from_graph(
240
- onnx_model_proto_graph: GraphProto,
251
+ graph_or_function: GraphProto | FunctionProto, /
241
252
  ) -> Iterable[TensorProto]:
242
- """Create an iterator of tensors from node attributes of an ONNX model graph."""
243
- for node in onnx_model_proto_graph.node:
253
+ """Create an iterator of tensors from node attributes of an ONNX model graph/function."""
254
+ for node in graph_or_function.node:
244
255
  for attribute in node.attribute:
245
256
  if attribute.HasField("t"):
246
257
  yield attribute.t
@@ -253,6 +264,8 @@ def _get_attribute_tensors_from_graph(
253
264
  def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
254
265
  """Create an iterator of tensors from node attributes of an ONNX model."""
255
266
  yield from _get_attribute_tensors_from_graph(onnx_model_proto.graph)
267
+ for function in onnx_model_proto.functions:
268
+ yield from _get_attribute_tensors_from_graph(function)
256
269
 
257
270
 
258
271
  def _is_valid_filename(filename: str) -> bool: