onnx 1.15.0__cp311-cp311-win_amd64.whl → 1.16.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx might be problematic. Click here for more details.

Files changed (584) hide show
  1. onnx/__init__.py +10 -10
  2. onnx/backend/base.py +13 -14
  3. onnx/backend/sample/ops/abs.py +1 -1
  4. onnx/backend/test/case/model/__init__.py +0 -1
  5. onnx/backend/test/case/node/ai_onnx_ml/tree_ensemble.py +122 -0
  6. onnx/backend/test/case/node/averagepool.py +15 -30
  7. onnx/backend/test/case/node/cast.py +88 -11
  8. onnx/backend/test/case/node/dequantizelinear.py +155 -0
  9. onnx/backend/test/case/node/groupnormalization.py +13 -9
  10. onnx/backend/test/case/node/gru.py +2 -2
  11. onnx/backend/test/case/node/isinf.py +4 -4
  12. onnx/backend/test/case/node/isnan.py +2 -2
  13. onnx/backend/test/case/node/lppool.py +8 -16
  14. onnx/backend/test/case/node/lstm.py +1 -1
  15. onnx/backend/test/case/node/maxpool.py +40 -34
  16. onnx/backend/test/case/node/pow.py +1 -1
  17. onnx/backend/test/case/node/qlinearmatmul.py +143 -109
  18. onnx/backend/test/case/node/quantizelinear.py +298 -7
  19. onnx/backend/test/case/node/reducemax.py +26 -0
  20. onnx/backend/test/case/node/rnn.py +1 -1
  21. onnx/backend/test/case/node/scan.py +6 -2
  22. onnx/backend/test/case/node/scatterelements.py +1 -1
  23. onnx/backend/test/case/node/topk.py +1 -1
  24. onnx/backend/test/case/utils.py +1 -3
  25. onnx/backend/test/data/node/test_ai_onnx_ml_tree_ensemble_set_membership/model.onnx +0 -0
  26. onnx/backend/test/data/node/test_ai_onnx_ml_tree_ensemble_set_membership/test_data_set_0/input_0.pb +0 -0
  27. onnx/backend/test/data/node/test_ai_onnx_ml_tree_ensemble_set_membership/test_data_set_0/output_0.pb +0 -0
  28. onnx/backend/test/data/node/test_ai_onnx_ml_tree_ensemble_single_tree/model.onnx +0 -0
  29. onnx/backend/test/data/node/test_ai_onnx_ml_tree_ensemble_single_tree/test_data_set_0/input_0.pb +1 -0
  30. onnx/backend/test/data/node/test_ai_onnx_ml_tree_ensemble_single_tree/test_data_set_0/output_0.pb +0 -0
  31. onnx/backend/test/data/node/test_cast_BFLOAT16_to_FLOAT/model.onnx +0 -0
  32. onnx/backend/test/data/node/test_cast_DOUBLE_to_FLOAT/model.onnx +0 -0
  33. onnx/backend/test/data/node/test_cast_DOUBLE_to_FLOAT16/model.onnx +0 -0
  34. onnx/backend/test/data/node/test_cast_FLOAT16_to_DOUBLE/model.onnx +0 -0
  35. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT/model.onnx +0 -0
  36. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E4M3FN/model.onnx +0 -0
  37. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E4M3FN/test_data_set_0/input_0.pb +2 -2
  38. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E4M3FN/test_data_set_0/output_0.pb +0 -0
  39. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E4M3FNUZ/model.onnx +0 -0
  40. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E4M3FNUZ/test_data_set_0/input_0.pb +2 -2
  41. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E4M3FNUZ/test_data_set_0/output_0.pb +0 -0
  42. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E5M2/model.onnx +0 -0
  43. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E5M2/test_data_set_0/input_0.pb +2 -2
  44. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E5M2/test_data_set_0/output_0.pb +0 -0
  45. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E5M2FNUZ/model.onnx +0 -0
  46. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E5M2FNUZ/test_data_set_0/input_0.pb +2 -2
  47. onnx/backend/test/data/node/test_cast_FLOAT16_to_FLOAT8E5M2FNUZ/test_data_set_0/output_0.pb +0 -0
  48. onnx/backend/test/data/node/test_cast_FLOAT16_to_INT4/model.onnx +0 -0
  49. onnx/backend/test/data/node/test_cast_FLOAT16_to_INT4/test_data_set_0/input_0.pb +0 -0
  50. onnx/backend/test/data/node/test_cast_FLOAT16_to_INT4/test_data_set_0/output_0.pb +1 -0
  51. onnx/backend/test/data/node/test_cast_FLOAT16_to_UINT4/model.onnx +0 -0
  52. onnx/backend/test/data/node/test_cast_FLOAT16_to_UINT4/test_data_set_0/input_0.pb +0 -0
  53. onnx/backend/test/data/node/test_cast_FLOAT16_to_UINT4/test_data_set_0/output_0.pb +0 -0
  54. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT/model.onnx +0 -0
  55. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT/test_data_set_0/input_0.pb +0 -0
  56. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT/test_data_set_0/output_0.pb +0 -0
  57. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT16/model.onnx +0 -0
  58. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT16/test_data_set_0/input_0.pb +0 -0
  59. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FNUZ_to_FLOAT16/test_data_set_0/output_0.pb +0 -0
  60. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT/model.onnx +0 -0
  61. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT/test_data_set_0/input_0.pb +0 -0
  62. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT/test_data_set_0/output_0.pb +0 -0
  63. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT16/model.onnx +0 -0
  64. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT16/test_data_set_0/input_0.pb +0 -0
  65. onnx/backend/test/data/node/test_cast_FLOAT8E4M3FN_to_FLOAT16/test_data_set_0/output_0.pb +0 -0
  66. onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT/model.onnx +0 -0
  67. onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT/test_data_set_0/input_0.pb +0 -0
  68. onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT/test_data_set_0/output_0.pb +0 -0
  69. onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT16/model.onnx +0 -0
  70. onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT16/test_data_set_0/input_0.pb +0 -0
  71. onnx/backend/test/data/node/test_cast_FLOAT8E5M2FNUZ_to_FLOAT16/test_data_set_0/output_0.pb +0 -0
  72. onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT/model.onnx +0 -0
  73. onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT/test_data_set_0/input_0.pb +0 -0
  74. onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT/test_data_set_0/output_0.pb +0 -0
  75. onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT16/model.onnx +0 -0
  76. onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT16/test_data_set_0/input_0.pb +0 -0
  77. onnx/backend/test/data/node/test_cast_FLOAT8E5M2_to_FLOAT16/test_data_set_0/output_0.pb +0 -0
  78. onnx/backend/test/data/node/test_cast_FLOAT_to_BFLOAT16/model.onnx +0 -0
  79. onnx/backend/test/data/node/test_cast_FLOAT_to_DOUBLE/model.onnx +0 -0
  80. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT16/model.onnx +0 -0
  81. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E4M3FN/model.onnx +0 -0
  82. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E4M3FN/test_data_set_0/input_0.pb +0 -0
  83. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E4M3FN/test_data_set_0/output_0.pb +0 -0
  84. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E4M3FNUZ/model.onnx +0 -0
  85. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E4M3FNUZ/test_data_set_0/input_0.pb +0 -0
  86. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E4M3FNUZ/test_data_set_0/output_0.pb +0 -0
  87. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E5M2/model.onnx +0 -0
  88. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E5M2/test_data_set_0/input_0.pb +0 -0
  89. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E5M2/test_data_set_0/output_0.pb +0 -0
  90. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E5M2FNUZ/model.onnx +0 -0
  91. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E5M2FNUZ/test_data_set_0/input_0.pb +0 -0
  92. onnx/backend/test/data/node/test_cast_FLOAT_to_FLOAT8E5M2FNUZ/test_data_set_0/output_0.pb +0 -0
  93. onnx/backend/test/data/node/test_cast_FLOAT_to_INT4/model.onnx +0 -0
  94. onnx/backend/test/data/node/test_cast_FLOAT_to_INT4/test_data_set_0/input_0.pb +0 -0
  95. onnx/backend/test/data/node/test_cast_FLOAT_to_INT4/test_data_set_0/output_0.pb +1 -0
  96. onnx/backend/test/data/node/test_cast_FLOAT_to_STRING/model.onnx +0 -0
  97. onnx/backend/test/data/node/test_cast_FLOAT_to_UINT4/model.onnx +0 -0
  98. onnx/backend/test/data/node/test_cast_FLOAT_to_UINT4/test_data_set_0/input_0.pb +0 -0
  99. onnx/backend/test/data/node/test_cast_FLOAT_to_UINT4/test_data_set_0/output_0.pb +0 -0
  100. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT/model.onnx +0 -0
  101. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT/test_data_set_0/input_0.pb +1 -0
  102. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT/test_data_set_0/output_0.pb +0 -0
  103. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT16/model.onnx +0 -0
  104. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT16/test_data_set_0/input_0.pb +1 -0
  105. onnx/backend/test/data/node/test_cast_INT4_to_FLOAT16/test_data_set_0/output_0.pb +0 -0
  106. onnx/backend/test/data/node/test_cast_INT4_to_INT8/model.onnx +0 -0
  107. onnx/backend/test/data/node/test_cast_INT4_to_INT8/test_data_set_0/input_0.pb +1 -0
  108. onnx/backend/test/data/node/test_cast_INT4_to_INT8/test_data_set_0/output_0.pb +0 -0
  109. onnx/backend/test/data/node/test_cast_STRING_to_FLOAT/model.onnx +0 -0
  110. onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT/model.onnx +0 -0
  111. onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT/test_data_set_0/input_0.pb +0 -0
  112. onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT/test_data_set_0/output_0.pb +0 -0
  113. onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT16/model.onnx +0 -0
  114. onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT16/test_data_set_0/input_0.pb +0 -0
  115. onnx/backend/test/data/node/test_cast_UINT4_to_FLOAT16/test_data_set_0/output_0.pb +0 -0
  116. onnx/backend/test/data/node/test_cast_UINT4_to_UINT8/model.onnx +0 -0
  117. onnx/backend/test/data/node/test_cast_UINT4_to_UINT8/test_data_set_0/input_0.pb +0 -0
  118. onnx/backend/test/data/node/test_cast_UINT4_to_UINT8/test_data_set_0/output_0.pb +0 -0
  119. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN/model.onnx +0 -0
  120. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN/test_data_set_0/input_0.pb +2 -2
  121. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN/test_data_set_0/output_0.pb +0 -0
  122. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ/model.onnx +0 -0
  123. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ/test_data_set_0/input_0.pb +2 -2
  124. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ/test_data_set_0/output_0.pb +0 -0
  125. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2/model.onnx +0 -0
  126. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2/test_data_set_0/input_0.pb +2 -2
  127. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2/test_data_set_0/output_0.pb +0 -0
  128. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ/model.onnx +0 -0
  129. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ/test_data_set_0/input_0.pb +2 -2
  130. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ/test_data_set_0/output_0.pb +0 -0
  131. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN/model.onnx +0 -0
  132. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN/test_data_set_0/input_0.pb +0 -0
  133. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN/test_data_set_0/output_0.pb +0 -0
  134. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ/model.onnx +0 -0
  135. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ/test_data_set_0/input_0.pb +0 -0
  136. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ/test_data_set_0/output_0.pb +0 -0
  137. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E5M2/model.onnx +0 -0
  138. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E5M2/test_data_set_0/input_0.pb +0 -0
  139. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E5M2/test_data_set_0/output_0.pb +0 -0
  140. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ/model.onnx +0 -0
  141. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ/test_data_set_0/input_0.pb +0 -0
  142. onnx/backend/test/data/node/test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ/test_data_set_0/output_0.pb +0 -0
  143. onnx/backend/test/data/node/test_castlike_BFLOAT16_to_FLOAT/model.onnx +0 -0
  144. onnx/backend/test/data/node/test_castlike_BFLOAT16_to_FLOAT_expanded/model.onnx +0 -0
  145. onnx/backend/test/data/node/test_castlike_DOUBLE_to_FLOAT/model.onnx +0 -0
  146. onnx/backend/test/data/node/test_castlike_DOUBLE_to_FLOAT16/model.onnx +0 -0
  147. onnx/backend/test/data/node/test_castlike_DOUBLE_to_FLOAT16_expanded/model.onnx +0 -0
  148. onnx/backend/test/data/node/test_castlike_DOUBLE_to_FLOAT_expanded/model.onnx +0 -0
  149. onnx/backend/test/data/node/test_castlike_FLOAT16_to_DOUBLE/model.onnx +0 -0
  150. onnx/backend/test/data/node/test_castlike_FLOAT16_to_DOUBLE_expanded/model.onnx +0 -0
  151. onnx/backend/test/data/node/test_castlike_FLOAT16_to_FLOAT/model.onnx +0 -0
  152. onnx/backend/test/data/node/test_castlike_FLOAT16_to_FLOAT_expanded/model.onnx +0 -0
  153. onnx/backend/test/data/node/test_castlike_FLOAT8E4M3FNUZ_to_FLOAT/model.onnx +0 -0
  154. onnx/backend/test/data/node/test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_expanded/model.onnx +0 -0
  155. onnx/backend/test/data/node/test_castlike_FLOAT8E4M3FN_to_FLOAT/model.onnx +0 -0
  156. onnx/backend/test/data/node/test_castlike_FLOAT8E4M3FN_to_FLOAT_expanded/model.onnx +0 -0
  157. onnx/backend/test/data/node/test_castlike_FLOAT8E5M2FNUZ_to_FLOAT/model.onnx +0 -0
  158. onnx/backend/test/data/node/test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_expanded/model.onnx +0 -0
  159. onnx/backend/test/data/node/test_castlike_FLOAT8E5M2_to_FLOAT/model.onnx +0 -0
  160. onnx/backend/test/data/node/test_castlike_FLOAT8E5M2_to_FLOAT_expanded/model.onnx +0 -0
  161. onnx/backend/test/data/node/test_castlike_FLOAT_to_BFLOAT16/model.onnx +0 -0
  162. onnx/backend/test/data/node/test_castlike_FLOAT_to_BFLOAT16_expanded/model.onnx +0 -0
  163. onnx/backend/test/data/node/test_castlike_FLOAT_to_DOUBLE/model.onnx +0 -0
  164. onnx/backend/test/data/node/test_castlike_FLOAT_to_DOUBLE_expanded/model.onnx +0 -0
  165. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT16/model.onnx +0 -0
  166. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT16_expanded/model.onnx +0 -0
  167. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E4M3FN/model.onnx +0 -0
  168. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E4M3FNUZ/model.onnx +0 -0
  169. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_expanded/model.onnx +0 -0
  170. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E4M3FN_expanded/model.onnx +0 -0
  171. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E5M2/model.onnx +0 -0
  172. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E5M2FNUZ/model.onnx +0 -0
  173. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded/model.onnx +0 -0
  174. onnx/backend/test/data/node/test_castlike_FLOAT_to_FLOAT8E5M2_expanded/model.onnx +0 -0
  175. onnx/backend/test/data/node/test_castlike_FLOAT_to_STRING/model.onnx +0 -0
  176. onnx/backend/test/data/node/test_castlike_FLOAT_to_STRING_expanded/model.onnx +0 -0
  177. onnx/backend/test/data/node/test_castlike_STRING_to_FLOAT/model.onnx +0 -0
  178. onnx/backend/test/data/node/test_castlike_STRING_to_FLOAT_expanded/model.onnx +0 -0
  179. onnx/backend/test/data/node/test_constant/model.onnx +0 -0
  180. onnx/backend/test/data/node/test_constant_pad/model.onnx +0 -0
  181. onnx/backend/test/data/node/test_constant_pad_axes/model.onnx +0 -0
  182. onnx/backend/test/data/node/test_constant_pad_negative_axes/model.onnx +0 -0
  183. onnx/backend/test/data/node/test_constantofshape_float_ones/model.onnx +0 -0
  184. onnx/backend/test/data/node/test_constantofshape_int_shape_zero/model.onnx +0 -0
  185. onnx/backend/test/data/node/test_constantofshape_int_zeros/model.onnx +0 -0
  186. onnx/backend/test/data/node/test_dequantizelinear/model.onnx +0 -0
  187. onnx/backend/test/data/node/test_dequantizelinear_axis/model.onnx +0 -0
  188. onnx/backend/test/data/node/test_dequantizelinear_blocked/model.onnx +0 -0
  189. onnx/backend/test/data/node/test_dequantizelinear_blocked/test_data_set_0/input_0.pb +1 -0
  190. onnx/backend/test/data/node/test_dequantizelinear_blocked/test_data_set_0/input_1.pb +0 -0
  191. onnx/backend/test/data/node/test_dequantizelinear_blocked/test_data_set_0/input_2.pb +0 -0
  192. onnx/backend/test/data/node/test_dequantizelinear_blocked/test_data_set_0/output_0.pb +0 -0
  193. onnx/backend/test/data/node/test_dequantizelinear_e4m3fn/model.onnx +0 -0
  194. onnx/backend/test/data/node/test_dequantizelinear_e4m3fn_float16/model.onnx +0 -0
  195. onnx/backend/test/data/node/test_dequantizelinear_e4m3fn_zero_point/model.onnx +0 -0
  196. onnx/backend/test/data/node/test_dequantizelinear_e5m2/model.onnx +0 -0
  197. onnx/backend/test/data/node/test_dequantizelinear_int16/model.onnx +0 -0
  198. onnx/backend/test/data/node/test_dequantizelinear_int16/test_data_set_0/input_0.pb +1 -0
  199. onnx/backend/test/data/node/test_dequantizelinear_int16/test_data_set_0/input_1.pb +0 -0
  200. onnx/backend/test/data/node/test_dequantizelinear_int16/test_data_set_0/input_2.pb +0 -0
  201. onnx/backend/test/data/node/test_dequantizelinear_int16/test_data_set_0/output_0.pb +0 -0
  202. onnx/backend/test/data/node/test_dequantizelinear_int4/model.onnx +0 -0
  203. onnx/backend/test/data/node/test_dequantizelinear_int4/test_data_set_0/input_0.pb +1 -0
  204. onnx/backend/test/data/node/test_dequantizelinear_int4/test_data_set_0/input_1.pb +0 -0
  205. onnx/backend/test/data/node/test_dequantizelinear_int4/test_data_set_0/input_2.pb +1 -0
  206. onnx/backend/test/data/node/test_dequantizelinear_int4/test_data_set_0/output_0.pb +0 -0
  207. onnx/backend/test/data/node/test_dequantizelinear_uint16/model.onnx +0 -0
  208. onnx/backend/test/data/node/test_dequantizelinear_uint16/test_data_set_0/input_0.pb +0 -0
  209. onnx/backend/test/data/node/test_dequantizelinear_uint16/test_data_set_0/input_1.pb +0 -0
  210. onnx/backend/test/data/node/test_dequantizelinear_uint16/test_data_set_0/input_2.pb +1 -0
  211. onnx/backend/test/data/node/test_dequantizelinear_uint16/test_data_set_0/output_0.pb +0 -0
  212. onnx/backend/test/data/node/test_dequantizelinear_uint4/model.onnx +0 -0
  213. onnx/backend/test/data/node/test_dequantizelinear_uint4/test_data_set_0/input_0.pb +1 -0
  214. onnx/backend/test/data/node/test_dequantizelinear_uint4/test_data_set_0/input_1.pb +0 -0
  215. onnx/backend/test/data/node/test_dequantizelinear_uint4/test_data_set_0/input_2.pb +1 -0
  216. onnx/backend/test/data/node/test_dequantizelinear_uint4/test_data_set_0/output_0.pb +0 -0
  217. onnx/backend/test/data/node/test_edge_pad/model.onnx +0 -0
  218. onnx/backend/test/data/node/test_flatten_axis0/model.onnx +0 -0
  219. onnx/backend/test/data/node/test_flatten_axis1/model.onnx +0 -0
  220. onnx/backend/test/data/node/test_flatten_axis2/model.onnx +0 -0
  221. onnx/backend/test/data/node/test_flatten_axis3/model.onnx +0 -0
  222. onnx/backend/test/data/node/test_flatten_default_axis/model.onnx +0 -0
  223. onnx/backend/test/data/node/test_flatten_negative_axis1/model.onnx +0 -0
  224. onnx/backend/test/data/node/test_flatten_negative_axis2/model.onnx +0 -0
  225. onnx/backend/test/data/node/test_flatten_negative_axis3/model.onnx +0 -0
  226. onnx/backend/test/data/node/test_flatten_negative_axis4/model.onnx +0 -0
  227. onnx/backend/test/data/node/test_group_normalization_epsilon/model.onnx +0 -0
  228. onnx/backend/test/data/node/test_group_normalization_epsilon/test_data_set_0/input_0.pb +1 -1
  229. onnx/backend/test/data/node/test_group_normalization_epsilon/test_data_set_0/input_1.pb +1 -1
  230. onnx/backend/test/data/node/test_group_normalization_epsilon/test_data_set_0/input_2.pb +1 -1
  231. onnx/backend/test/data/node/test_group_normalization_epsilon/test_data_set_0/output_0.pb +0 -0
  232. onnx/backend/test/data/node/test_group_normalization_epsilon_expanded/model.onnx +0 -0
  233. onnx/backend/test/data/node/test_group_normalization_epsilon_expanded/test_data_set_0/input_0.pb +1 -1
  234. onnx/backend/test/data/node/test_group_normalization_epsilon_expanded/test_data_set_0/input_1.pb +1 -1
  235. onnx/backend/test/data/node/test_group_normalization_epsilon_expanded/test_data_set_0/input_2.pb +1 -1
  236. onnx/backend/test/data/node/test_group_normalization_epsilon_expanded/test_data_set_0/output_0.pb +0 -0
  237. onnx/backend/test/data/node/test_group_normalization_example/model.onnx +0 -0
  238. onnx/backend/test/data/node/test_group_normalization_example/test_data_set_0/input_1.pb +1 -1
  239. onnx/backend/test/data/node/test_group_normalization_example/test_data_set_0/input_2.pb +1 -1
  240. onnx/backend/test/data/node/test_group_normalization_example/test_data_set_0/output_0.pb +0 -0
  241. onnx/backend/test/data/node/test_group_normalization_example_expanded/model.onnx +0 -0
  242. onnx/backend/test/data/node/test_group_normalization_example_expanded/test_data_set_0/input_1.pb +1 -1
  243. onnx/backend/test/data/node/test_group_normalization_example_expanded/test_data_set_0/input_2.pb +1 -1
  244. onnx/backend/test/data/node/test_group_normalization_example_expanded/test_data_set_0/output_0.pb +0 -0
  245. onnx/backend/test/data/node/test_identity/model.onnx +0 -0
  246. onnx/backend/test/data/node/test_identity_sequence/model.onnx +0 -0
  247. onnx/backend/test/data/node/test_lrn_default/test_data_set_0/output_0.pb +0 -0
  248. onnx/backend/test/data/node/test_maxpool_2d_ceil_output_size_reduce_by_one/model.onnx +0 -0
  249. onnx/backend/test/data/node/test_maxpool_2d_ceil_output_size_reduce_by_one/test_data_set_0/input_0.pb +0 -0
  250. onnx/backend/test/data/node/test_maxpool_2d_ceil_output_size_reduce_by_one/test_data_set_0/output_0.pb +0 -0
  251. onnx/backend/test/data/node/test_mvn/test_data_set_0/output_0.pb +1 -1
  252. onnx/backend/test/data/node/test_mvn_expanded/test_data_set_0/output_0.pb +1 -1
  253. onnx/backend/test/data/node/test_mvn_expanded_ver18/test_data_set_0/output_0.pb +1 -1
  254. onnx/backend/test/data/node/test_pow/test_data_set_0/output_0.pb +0 -0
  255. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/model.onnx +0 -0
  256. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_0.pb +1 -0
  257. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_1.pb +2 -0
  258. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_2.pb +1 -0
  259. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_3.pb +0 -0
  260. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_4.pb +2 -0
  261. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_5.pb +1 -0
  262. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_6.pb +2 -0
  263. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/input_7.pb +1 -0
  264. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float16/test_data_set_0/output_0.pb +1 -0
  265. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/model.onnx +0 -0
  266. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/test_data_set_0/input_0.pb +1 -0
  267. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/test_data_set_0/input_2.pb +1 -0
  268. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/test_data_set_0/input_3.pb +0 -0
  269. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/test_data_set_0/input_5.pb +1 -0
  270. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/test_data_set_0/input_7.pb +1 -0
  271. onnx/backend/test/data/node/test_qlinearmatmul_2D_int8_float32/test_data_set_0/output_0.pb +1 -0
  272. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float16/model.onnx +0 -0
  273. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float16/test_data_set_0/input_1.pb +2 -0
  274. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float16/test_data_set_0/input_4.pb +2 -0
  275. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float16/test_data_set_0/input_6.pb +2 -0
  276. onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float32}/model.onnx +0 -0
  277. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float32/test_data_set_0/input_0.pb +0 -0
  278. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float32/test_data_set_0/input_3.pb +0 -0
  279. onnx/backend/test/data/node/test_qlinearmatmul_2D_uint8_float32/test_data_set_0/output_0.pb +1 -0
  280. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/model.onnx +0 -0
  281. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_0.pb +1 -0
  282. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_1.pb +2 -0
  283. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_2.pb +1 -0
  284. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_3.pb +0 -0
  285. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_4.pb +2 -0
  286. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_5.pb +1 -0
  287. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_6.pb +2 -0
  288. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/input_7.pb +1 -0
  289. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float16/test_data_set_0/output_0.pb +1 -0
  290. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/model.onnx +0 -0
  291. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_0.pb +1 -0
  292. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_1.pb +1 -0
  293. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_2.pb +1 -0
  294. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_3.pb +0 -0
  295. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_4.pb +1 -0
  296. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_5.pb +1 -0
  297. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_6.pb +1 -0
  298. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/input_7.pb +1 -0
  299. onnx/backend/test/data/node/test_qlinearmatmul_3D_int8_float32/test_data_set_0/output_0.pb +1 -0
  300. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/model.onnx +0 -0
  301. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/test_data_set_0/input_1.pb +2 -0
  302. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/test_data_set_0/input_2.pb +1 -0
  303. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/test_data_set_0/input_4.pb +2 -0
  304. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/test_data_set_0/input_5.pb +1 -0
  305. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/test_data_set_0/input_6.pb +2 -0
  306. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float16/test_data_set_0/input_7.pb +1 -0
  307. onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_3D_uint8_float32}/model.onnx +0 -0
  308. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_0.pb +0 -0
  309. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_1.pb +1 -0
  310. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_2.pb +1 -0
  311. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_3.pb +0 -0
  312. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_4.pb +1 -0
  313. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_5.pb +1 -0
  314. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_6.pb +1 -0
  315. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/input_7.pb +1 -0
  316. onnx/backend/test/data/node/test_qlinearmatmul_3D_uint8_float32/test_data_set_0/output_0.pb +1 -0
  317. onnx/backend/test/data/node/test_quantizelinear/model.onnx +0 -0
  318. onnx/backend/test/data/node/test_quantizelinear_axis/model.onnx +0 -0
  319. onnx/backend/test/data/node/test_quantizelinear_blocked_asymmetric/model.onnx +0 -0
  320. onnx/backend/test/data/node/test_quantizelinear_blocked_asymmetric/test_data_set_0/input_0.pb +0 -0
  321. onnx/backend/test/data/node/test_quantizelinear_blocked_asymmetric/test_data_set_0/input_1.pb +0 -0
  322. onnx/backend/test/data/node/test_quantizelinear_blocked_asymmetric/test_data_set_0/input_2.pb +0 -0
  323. onnx/backend/test/data/node/test_quantizelinear_blocked_asymmetric/test_data_set_0/output_0.pb +1 -0
  324. onnx/backend/test/data/node/test_quantizelinear_blocked_symmetric/model.onnx +0 -0
  325. onnx/backend/test/data/node/test_quantizelinear_blocked_symmetric/test_data_set_0/input_0.pb +0 -0
  326. onnx/backend/test/data/node/test_quantizelinear_blocked_symmetric/test_data_set_0/input_1.pb +0 -0
  327. onnx/backend/test/data/node/test_quantizelinear_blocked_symmetric/test_data_set_0/output_0.pb +0 -0
  328. onnx/backend/test/data/node/test_quantizelinear_e4m3fn/model.onnx +0 -0
  329. onnx/backend/test/data/node/test_quantizelinear_e4m3fn/test_data_set_0/input_2.pb +0 -0
  330. onnx/backend/test/data/node/test_quantizelinear_e4m3fn/test_data_set_0/output_0.pb +0 -0
  331. onnx/backend/test/data/node/test_quantizelinear_e5m2/model.onnx +0 -0
  332. onnx/backend/test/data/node/test_quantizelinear_e5m2/test_data_set_0/input_2.pb +0 -0
  333. onnx/backend/test/data/node/test_quantizelinear_e5m2/test_data_set_0/output_0.pb +0 -0
  334. onnx/backend/test/data/node/test_quantizelinear_int16/model.onnx +0 -0
  335. onnx/backend/test/data/node/test_quantizelinear_int16/test_data_set_0/input_0.pb +0 -0
  336. onnx/backend/test/data/node/test_quantizelinear_int16/test_data_set_0/input_1.pb +0 -0
  337. onnx/backend/test/data/node/test_quantizelinear_int16/test_data_set_0/input_2.pb +0 -0
  338. onnx/backend/test/data/node/test_quantizelinear_int16/test_data_set_0/output_0.pb +0 -0
  339. onnx/backend/test/data/node/test_quantizelinear_int4/model.onnx +0 -0
  340. onnx/backend/test/data/node/test_quantizelinear_int4/test_data_set_0/input_0.pb +0 -0
  341. onnx/backend/test/data/node/test_quantizelinear_int4/test_data_set_0/input_1.pb +0 -0
  342. onnx/backend/test/data/node/test_quantizelinear_int4/test_data_set_0/input_2.pb +1 -0
  343. onnx/backend/test/data/node/test_quantizelinear_int4/test_data_set_0/output_0.pb +1 -0
  344. onnx/backend/test/data/node/test_quantizelinear_uint16/model.onnx +0 -0
  345. onnx/backend/test/data/node/test_quantizelinear_uint16/test_data_set_0/input_0.pb +0 -0
  346. onnx/backend/test/data/node/test_quantizelinear_uint16/test_data_set_0/input_1.pb +0 -0
  347. onnx/backend/test/data/node/test_quantizelinear_uint16/test_data_set_0/input_2.pb +1 -0
  348. onnx/backend/test/data/node/test_quantizelinear_uint16/test_data_set_0/output_0.pb +0 -0
  349. onnx/backend/test/data/node/test_quantizelinear_uint4/model.onnx +0 -0
  350. onnx/backend/test/data/node/test_quantizelinear_uint4/test_data_set_0/input_0.pb +0 -0
  351. onnx/backend/test/data/node/test_quantizelinear_uint4/test_data_set_0/input_1.pb +0 -0
  352. onnx/backend/test/data/node/test_quantizelinear_uint4/test_data_set_0/input_2.pb +1 -0
  353. onnx/backend/test/data/node/test_quantizelinear_uint4/test_data_set_0/output_0.pb +0 -0
  354. onnx/backend/test/data/node/test_reflect_pad/model.onnx +0 -0
  355. onnx/backend/test/data/node/test_reshape_allowzero_reordered/model.onnx +0 -0
  356. onnx/backend/test/data/node/test_reshape_extended_dims/model.onnx +0 -0
  357. onnx/backend/test/data/node/test_reshape_negative_dim/model.onnx +0 -0
  358. onnx/backend/test/data/node/test_reshape_negative_extended_dims/model.onnx +0 -0
  359. onnx/backend/test/data/node/test_reshape_one_dim/model.onnx +0 -0
  360. onnx/backend/test/data/node/test_reshape_reduced_dims/model.onnx +0 -0
  361. onnx/backend/test/data/node/test_reshape_reordered_all_dims/model.onnx +0 -0
  362. onnx/backend/test/data/node/test_reshape_reordered_last_dims/model.onnx +0 -0
  363. onnx/backend/test/data/node/test_reshape_zero_and_negative_dim/model.onnx +0 -0
  364. onnx/backend/test/data/node/test_reshape_zero_dim/model.onnx +0 -0
  365. onnx/backend/test/data/node/test_shape/model.onnx +0 -0
  366. onnx/backend/test/data/node/test_shape_clip_end/model.onnx +0 -0
  367. onnx/backend/test/data/node/test_shape_clip_start/model.onnx +0 -0
  368. onnx/backend/test/data/node/test_shape_end_1/model.onnx +0 -0
  369. onnx/backend/test/data/node/test_shape_end_negative_1/model.onnx +0 -0
  370. onnx/backend/test/data/node/test_shape_example/model.onnx +0 -0
  371. onnx/backend/test/data/node/test_shape_start_1/model.onnx +0 -0
  372. onnx/backend/test/data/node/test_shape_start_1_end_2/model.onnx +0 -0
  373. onnx/backend/test/data/node/test_shape_start_1_end_negative_1/model.onnx +0 -0
  374. onnx/backend/test/data/node/test_shape_start_negative_1/model.onnx +0 -0
  375. onnx/backend/test/data/node/test_size/model.onnx +0 -0
  376. onnx/backend/test/data/node/test_size_example/model.onnx +0 -0
  377. onnx/backend/test/data/node/test_squeeze/model.onnx +0 -0
  378. onnx/backend/test/data/node/test_squeeze_negative_axes/model.onnx +0 -0
  379. onnx/backend/test/data/node/test_transpose_all_permutations_0/model.onnx +0 -0
  380. onnx/backend/test/data/node/test_transpose_all_permutations_1/model.onnx +0 -0
  381. onnx/backend/test/data/node/test_transpose_all_permutations_2/model.onnx +0 -0
  382. onnx/backend/test/data/node/test_transpose_all_permutations_3/model.onnx +0 -0
  383. onnx/backend/test/data/node/test_transpose_all_permutations_4/model.onnx +0 -0
  384. onnx/backend/test/data/node/test_transpose_all_permutations_5/model.onnx +0 -0
  385. onnx/backend/test/data/node/test_transpose_default/model.onnx +0 -0
  386. onnx/backend/test/data/node/test_unsqueeze_axis_0/model.onnx +0 -0
  387. onnx/backend/test/data/node/test_unsqueeze_axis_1/model.onnx +0 -0
  388. onnx/backend/test/data/node/test_unsqueeze_axis_2/model.onnx +0 -0
  389. onnx/backend/test/data/node/test_unsqueeze_negative_axes/model.onnx +0 -0
  390. onnx/backend/test/data/node/test_unsqueeze_three_axes/model.onnx +0 -0
  391. onnx/backend/test/data/node/test_unsqueeze_two_axes/model.onnx +0 -0
  392. onnx/backend/test/data/node/test_unsqueeze_unsorted_axes/model.onnx +0 -0
  393. onnx/backend/test/data/node/test_wrap_pad/model.onnx +0 -0
  394. onnx/backend/test/loader/__init__.py +0 -1
  395. onnx/backend/test/runner/__init__.py +43 -15
  396. onnx/checker.cc +104 -99
  397. onnx/checker.h +23 -3
  398. onnx/checker.py +56 -20
  399. onnx/common/assertions.cc +10 -5
  400. onnx/common/common.h +19 -0
  401. onnx/common/file_utils.h +3 -1
  402. onnx/common/interned_strings.h +7 -1
  403. onnx/common/ir.h +30 -7
  404. onnx/common/ir_pb_converter.cc +6 -0
  405. onnx/common/path.h +18 -2
  406. onnx/common/proto_util.h +43 -0
  407. onnx/common/version.h +1 -1
  408. onnx/cpp2py_export.cc +88 -56
  409. onnx/defs/__init__.py +29 -8
  410. onnx/defs/controlflow/defs.cc +16 -16
  411. onnx/defs/controlflow/old.cc +177 -0
  412. onnx/defs/data_propagators.h +2 -0
  413. onnx/defs/data_type_utils.cc +2 -0
  414. onnx/defs/generator/defs.cc +6 -4
  415. onnx/defs/generator/old.cc +115 -0
  416. onnx/defs/math/defs.cc +37 -142
  417. onnx/defs/math/old.cc +96 -12
  418. onnx/defs/math/utils.cc +127 -0
  419. onnx/defs/math/utils.h +8 -0
  420. onnx/defs/nn/defs.cc +72 -59
  421. onnx/defs/nn/old.cc +181 -2
  422. onnx/defs/object_detection/defs.cc +2 -2
  423. onnx/defs/object_detection/old.cc +2 -2
  424. onnx/defs/operator_sets.h +51 -0
  425. onnx/defs/operator_sets_ml.h +14 -0
  426. onnx/defs/parser.cc +112 -54
  427. onnx/defs/parser.h +14 -2
  428. onnx/defs/printer.cc +14 -7
  429. onnx/defs/quantization/defs.cc +111 -44
  430. onnx/defs/quantization/old.cc +130 -1
  431. onnx/defs/schema.cc +62 -18
  432. onnx/defs/schema.h +194 -48
  433. onnx/defs/shape_inference.cc +28 -19
  434. onnx/defs/shape_inference.h +2 -0
  435. onnx/defs/tensor/defs.cc +54 -96
  436. onnx/defs/tensor/old.cc +939 -34
  437. onnx/defs/tensor/utils.cc +6 -3
  438. onnx/defs/tensor/utils.h +5 -1
  439. onnx/defs/tensor_proto_util.cc +2 -0
  440. onnx/defs/tensor_util.cc +2 -0
  441. onnx/defs/traditionalml/defs.cc +273 -117
  442. onnx/defs/traditionalml/old.cc +329 -14
  443. onnx/defs/traditionalml/utils.h +27 -0
  444. onnx/external_data_helper.py +12 -26
  445. onnx/helper.py +242 -169
  446. onnx/hub.py +104 -70
  447. onnx/inliner/inliner.cc +89 -31
  448. onnx/inliner/inliner.h +5 -0
  449. onnx/inliner.py +2 -0
  450. onnx/mapping.py +9 -0
  451. onnx/model_container.py +346 -0
  452. onnx/numpy_helper.py +100 -38
  453. onnx/onnx-ml.proto +50 -13
  454. onnx/onnx.in.proto +50 -13
  455. onnx/onnx.proto +50 -13
  456. onnx/onnx_cpp2py_export/__init__.pyi +5 -0
  457. onnx/onnx_cpp2py_export/checker.pyi +21 -0
  458. onnx/onnx_cpp2py_export/defs.pyi +202 -0
  459. onnx/onnx_cpp2py_export/inliner.pyi +19 -0
  460. onnx/onnx_cpp2py_export/parser.pyi +32 -0
  461. onnx/onnx_cpp2py_export/printer.pyi +3 -0
  462. onnx/onnx_cpp2py_export/shape_inference.pyi +16 -0
  463. onnx/onnx_cpp2py_export/version_converter.pyi +4 -0
  464. onnx/onnx_cpp2py_export.cp311-win_amd64.pyd +0 -0
  465. onnx/onnx_data_pb2.pyi +146 -0
  466. onnx/onnx_ml_pb2.py +52 -52
  467. onnx/onnx_ml_pb2.pyi +663 -0
  468. onnx/onnx_operators_ml_pb2.pyi +67 -0
  469. onnx/reference/__init__.py +2 -0
  470. onnx/reference/custom_element_types.py +2 -0
  471. onnx/reference/op_run.py +166 -121
  472. onnx/reference/ops/_op.py +27 -50
  473. onnx/reference/ops/_op_list.py +36 -24
  474. onnx/reference/ops/aionnx_preview_training/_op_list.py +15 -8
  475. onnx/reference/ops/aionnxml/_common_classifier.py +3 -5
  476. onnx/reference/ops/aionnxml/_op_list.py +16 -8
  477. onnx/reference/ops/aionnxml/op_array_feature_extractor.py +4 -6
  478. onnx/reference/ops/aionnxml/op_linear_classifier.py +1 -2
  479. onnx/reference/ops/aionnxml/op_normalizer.py +3 -3
  480. onnx/reference/ops/aionnxml/op_svm_helper.py +1 -3
  481. onnx/reference/ops/aionnxml/op_svm_regressor.py +1 -3
  482. onnx/reference/ops/aionnxml/op_tree_ensemble.py +257 -0
  483. onnx/reference/ops/aionnxml/op_tree_ensemble_helper.py +2 -6
  484. onnx/reference/ops/aionnxml/op_tree_ensemble_regressor.py +4 -4
  485. onnx/reference/ops/experimental/_op_list.py +15 -8
  486. onnx/reference/ops/op_blackman_window.py +5 -6
  487. onnx/reference/ops/op_cast.py +22 -0
  488. onnx/reference/ops/op_cast_like.py +6 -0
  489. onnx/reference/ops/op_clip.py +5 -8
  490. onnx/reference/ops/op_col2im.py +1 -3
  491. onnx/reference/ops/op_constant.py +7 -1
  492. onnx/reference/ops/op_dequantize_linear.py +43 -40
  493. onnx/reference/ops/op_det.py +1 -1
  494. onnx/reference/ops/op_dynamic_quantize_linear.py +2 -2
  495. onnx/reference/ops/op_grid_sample.py +2 -4
  496. onnx/reference/ops/op_hamming_window.py +3 -6
  497. onnx/reference/ops/op_hann_window.py +3 -6
  498. onnx/reference/ops/op_if.py +4 -3
  499. onnx/reference/ops/op_loop.py +7 -9
  500. onnx/reference/ops/op_matmul.py +1 -2
  501. onnx/reference/ops/op_max_pool.py +5 -0
  502. onnx/reference/ops/op_optional.py +1 -1
  503. onnx/reference/ops/op_pool_common.py +3 -6
  504. onnx/reference/ops/op_qlinear_matmul.py +2 -2
  505. onnx/reference/ops/op_quantize_linear.py +166 -71
  506. onnx/reference/ops/op_resize.py +25 -21
  507. onnx/reference/ops/op_rnn.py +20 -12
  508. onnx/reference/ops/op_scan.py +23 -15
  509. onnx/reference/ops/op_scatter_elements.py +7 -6
  510. onnx/reference/ops/op_stft.py +3 -5
  511. onnx/reference/ops/op_string_normalizer.py +7 -7
  512. onnx/reference/ops/op_tfidf_vectorizer.py +7 -8
  513. onnx/reference/ops/op_topk.py +9 -11
  514. onnx/reference/ops/op_unique.py +1 -1
  515. onnx/reference/reference_evaluator.py +119 -63
  516. onnx/shape_inference/implementation.cc +160 -127
  517. onnx/shape_inference.py +11 -10
  518. onnx/subbyte.py +72 -0
  519. onnx/test/__init__.pyi +6 -0
  520. onnx/test/checker_test.py +21 -1
  521. onnx/test/compose_test.py +26 -74
  522. onnx/test/cpp/inliner_test.cc +76 -1
  523. onnx/test/cpp/ir_test.cc +60 -0
  524. onnx/test/cpp/parser_test.cc +106 -0
  525. onnx/test/function_test.py +1 -3
  526. onnx/test/helper_test.py +64 -4
  527. onnx/test/model_container_refeval_test.py +139 -0
  528. onnx/test/model_container_test.py +136 -0
  529. onnx/test/model_inference_test.py +44 -0
  530. onnx/test/reference_evaluator_ml_test.py +448 -47
  531. onnx/test/reference_evaluator_model_test.py +130 -0
  532. onnx/test/reference_evaluator_test.py +901 -14
  533. onnx/test/schema_test.py +166 -1
  534. onnx/test/shape_inference_test.py +285 -6
  535. onnx/test/symbolic_shape_test.py +3 -8
  536. onnx/test/test_backend_onnxruntime.py +238 -224
  537. onnx/test/test_backend_reference.py +11 -0
  538. onnx/test/test_external_data.py +51 -2
  539. onnx/test/version_converter/automatic_conversion_test_base.py +2 -1
  540. onnx/test/version_converter/automatic_upgrade_test.py +12 -10
  541. onnx/test/version_converter_test.py +166 -0
  542. onnx/tools/replace_constants.py +23 -26
  543. onnx/tools/update_model_dims.py +1 -2
  544. onnx/version.py +2 -2
  545. onnx/version_converter/adapters/group_normalization_20_21.h +128 -0
  546. onnx/version_converter/adapters/q_dq_21_20.h +77 -0
  547. onnx/version_converter/convert.h +67 -2
  548. onnx/version_converter.py +6 -142
  549. {onnx-1.15.0.dist-info → onnx-1.16.1.dist-info}/METADATA +18 -15
  550. {onnx-1.15.0.dist-info → onnx-1.16.1.dist-info}/RECORD +572 -406
  551. {onnx-1.15.0.dist-info → onnx-1.16.1.dist-info}/WHEEL +1 -1
  552. onnx/examples/Protobufs.ipynb +0 -639
  553. onnx/examples/check_model.ipynb +0 -128
  554. onnx/examples/load_model.ipynb +0 -116
  555. onnx/examples/make_model.ipynb +0 -176
  556. onnx/examples/np_array_tensorproto.ipynb +0 -136
  557. onnx/examples/resources/single_relu.onnx +0 -12
  558. onnx/examples/resources/single_relu_new.onnx +0 -12
  559. onnx/examples/resources/tensor.pb +0 -0
  560. onnx/examples/resources/two_transposes.onnx +0 -0
  561. onnx/examples/save_model.ipynb +0 -56
  562. onnx/examples/shape_inference.ipynb +0 -111
  563. onnx/test/reference_evaluator_backend_test.py +0 -876
  564. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_int8_float32}/test_data_set_0/input_1.pb +0 -0
  565. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_int8_float32}/test_data_set_0/input_4.pb +0 -0
  566. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_int8_float32}/test_data_set_0/input_6.pb +0 -0
  567. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float16}/test_data_set_0/input_0.pb +0 -0
  568. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float16}/test_data_set_0/input_2.pb +0 -0
  569. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float16}/test_data_set_0/input_3.pb +0 -0
  570. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float16}/test_data_set_0/input_5.pb +0 -0
  571. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float16}/test_data_set_0/input_7.pb +0 -0
  572. /onnx/backend/test/data/node/{test_qlinearmatmul_2D → test_qlinearmatmul_2D_uint8_float16}/test_data_set_0/output_0.pb +0 -0
  573. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_2D_uint8_float32}/test_data_set_0/input_1.pb +0 -0
  574. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_2D_uint8_float32}/test_data_set_0/input_2.pb +0 -0
  575. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_2D_uint8_float32}/test_data_set_0/input_4.pb +0 -0
  576. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_2D_uint8_float32}/test_data_set_0/input_5.pb +0 -0
  577. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_2D_uint8_float32}/test_data_set_0/input_6.pb +0 -0
  578. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_2D_uint8_float32}/test_data_set_0/input_7.pb +0 -0
  579. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_3D_uint8_float16}/test_data_set_0/input_0.pb +0 -0
  580. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_3D_uint8_float16}/test_data_set_0/input_3.pb +0 -0
  581. /onnx/backend/test/data/node/{test_qlinearmatmul_3D → test_qlinearmatmul_3D_uint8_float16}/test_data_set_0/output_0.pb +0 -0
  582. {onnx-1.15.0.dist-info → onnx-1.16.1.dist-info}/LICENSE +0 -0
  583. {onnx-1.15.0.dist-info → onnx-1.16.1.dist-info}/entry_points.txt +0 -0
  584. {onnx-1.15.0.dist-info → onnx-1.16.1.dist-info}/top_level.txt +0 -0
@@ -3,6 +3,7 @@
3
3
  */
4
4
 
5
5
  #include "onnx/defs/schema.h"
6
+ #include "onnx/defs/traditionalml/utils.h"
6
7
 
7
8
  #ifdef ONNX_ML
8
9
  namespace ONNX_NAMESPACE {
@@ -48,7 +49,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
48
49
  num_indices *= indices_shape.dim(i).dim_value();
49
50
  } else if (indices_shape.dim(i).has_dim_param()) {
50
51
  if (single_symbolic_dim.empty()) {
51
- // it is possible to set symbolic dimension param if the rest dim values are all value 1
52
+ // it is possible to set symbolic dimension param if the rest dim values are all
53
+ // value 1
52
54
  single_symbolic_dim = indices_shape.dim(i).dim_param();
53
55
  } else {
54
56
  return;
@@ -111,12 +113,14 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
111
113
  "The output is a 1-D tensor of string, float, or integer.")
112
114
  .Attr(
113
115
  "cast_to",
114
- "A string indicating the desired element type of the output tensor, one of 'TO_FLOAT', 'TO_STRING', 'TO_INT64'.",
116
+ "A string indicating the desired element type of the output tensor, one of 'TO_FLOAT', 'TO_STRING', "
117
+ "'TO_INT64'.",
115
118
  AttributeProto::STRING,
116
119
  std::string("TO_FLOAT"))
117
120
  .Attr(
118
121
  "map_form",
119
- "Indicates whether to only output as many values as are in the input (dense), or position the input based on using the key of the map as the index of the output (sparse).<br>One of 'DENSE', 'SPARSE'.",
122
+ "Indicates whether to only output as many values as are in the input (dense), or position the input based "
123
+ "on using the key of the map as the index of the output (sparse).<br>One of 'DENSE', 'SPARSE'.",
120
124
  AttributeProto::STRING,
121
125
  std::string("DENSE"))
122
126
  .Attr(
@@ -179,12 +183,14 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
179
183
  OPTIONAL_VALUE)
180
184
  .Attr(
181
185
  "default_string",
182
- "A string to use when an input integer value is not found in the map.<br>One and only one of the 'default_*' attributes must be defined.",
186
+ "A string to use when an input integer value is not found in the map.<br>One and only one of the "
187
+ "'default_*' attributes must be defined.",
183
188
  AttributeProto::STRING,
184
189
  std::string("_Unused"))
185
190
  .Attr(
186
191
  "default_int64",
187
- "An integer to use when an input string value is not found in the map.<br>One and only one of the 'default_*' attributes must be defined.",
192
+ "An integer to use when an input string value is not found in the map.<br>One and only one of the "
193
+ "'default_*' attributes must be defined.",
188
194
  AttributeProto::INT,
189
195
  static_cast<int64_t>(-1))
190
196
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
@@ -230,11 +236,13 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
230
236
  "map(int64, double)",
231
237
  "map(string, float)",
232
238
  "map(string, double)"},
233
- "The input must be a map from strings or integers to either strings or a numeric type. The key and value types cannot be the same.")
239
+ "The input must be a map from strings or integers to either strings or a numeric type. The key and value "
240
+ "types cannot be the same.")
234
241
  .TypeConstraint(
235
242
  "T2",
236
243
  {"tensor(int64)", "tensor(float)", "tensor(double)", "tensor(string)"},
237
- "The output will be a tensor of the value type of the input map. It's shape will be [1,C], where C is the length of the input dictionary.")
244
+ "The output will be a tensor of the value type of the input map. It's shape will be [1,C], where C is the "
245
+ "length of the input dictionary.")
238
246
  .Attr(
239
247
  "string_vocabulary",
240
248
  "A string vocabulary array.<br>One and only one of the vocabularies must be defined.",
@@ -292,7 +300,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
292
300
  .TypeConstraint(
293
301
  "T",
294
302
  {"tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
295
- "The input type must be a tensor of a numeric type, either [N,C] or [C]. The output type will be of the same tensor type and shape.")
303
+ "The input type must be a tensor of a numeric type, either [N,C] or [C]. The output type will be of the "
304
+ "same tensor type and shape.")
296
305
  .Attr("imputed_value_floats", "Value(s) to change to", AttributeProto::FLOATS, OPTIONAL_VALUE)
297
306
  .Attr("replaced_value_float", "A value that needs replacing.", AttributeProto::FLOAT, 0.f)
298
307
  .Attr("imputed_value_int64s", "Value(s) to change to.", AttributeProto::INTS, OPTIONAL_VALUE)
@@ -356,9 +365,10 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
356
365
  .Attr("default_float", "A float.", AttributeProto::FLOAT, -0.f)
357
366
  .Attr(
358
367
  "default_tensor",
359
- "A default tensor.",
360
- "{\"_Unused\"} if values_* has string type, {-1} if values_* has integral type, and {-0.f} if values_* has float type.",
361
- AttributeProto::TENSOR)
368
+ "A default tensor. {\"_Unused\"} if values_* has string type, {-1} if values_* has integral type, and "
369
+ "{-0.f} if values_* has float type.",
370
+ AttributeProto::TENSOR,
371
+ OPTIONAL_VALUE)
362
372
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
363
373
  int key_length, key_type;
364
374
  std::tie(key_type, key_length) =
@@ -427,7 +437,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
427
437
  .TypeConstraint(
428
438
  "T1",
429
439
  {"tensor(float)", "tensor(double)", "tensor(int64)", "tensor(int32)"},
430
- "The input must be a tensor of a numeric type, and of shape [N,C] or [C]. In the latter case, it will be treated as [1,C]")
440
+ "The input must be a tensor of a numeric type, and of shape [N,C] or [C]. In the latter case, it will be "
441
+ "treated as [1,C]")
431
442
  .TypeConstraint(
432
443
  "T2",
433
444
  {"tensor(string)", "tensor(int64)"},
@@ -451,7 +462,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
451
462
  OPTIONAL_VALUE)
452
463
  .Attr(
453
464
  "post_transform",
454
- "Indicates the transform to apply to the scores vector.<br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT'",
465
+ "Indicates the transform to apply to the scores vector.<br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' "
466
+ "'SOFTMAX_ZERO,' or 'PROBIT'",
455
467
  AttributeProto::STRING,
456
468
  std::string("NONE"))
457
469
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
@@ -529,7 +541,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
529
541
  "The input must be a tensor of a numeric type.")
530
542
  .Attr(
531
543
  "post_transform",
532
- "Indicates the transform to apply to the regression output vector.<br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT'",
544
+ "Indicates the transform to apply to the regression output vector.<br>One of 'NONE,' 'SOFTMAX,' "
545
+ "'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT'",
533
546
  AttributeProto::STRING,
534
547
  std::string("NONE"))
535
548
  .Attr("coefficients", "Weights of the model(s).", AttributeProto::FLOATS, OPTIONAL_VALUE)
@@ -600,7 +613,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
600
613
  OPTIONAL_VALUE)
601
614
  .Attr(
602
615
  "zeros",
603
- "If true and category is not present, will return all zeros; if false and a category if not found, the operator will fail.",
616
+ "If true and category is not present, will return all zeros; if false and a category if not found, the "
617
+ "operator will fail.",
604
618
  AttributeProto::INT,
605
619
  static_cast<int64_t>(1))
606
620
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
@@ -637,12 +651,14 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
637
651
  "The input must be a tensor of a numeric type.")
638
652
  .Attr(
639
653
  "offset",
640
- "First, offset by this.<br>Can be length of features in an [N,F] tensor or length 1, in which case it applies to all features, regardless of dimension count.",
654
+ "First, offset by this.<br>Can be length of features in an [N,F] tensor or length 1, in which case it "
655
+ "applies to all features, regardless of dimension count.",
641
656
  AttributeProto::FLOATS,
642
657
  OPTIONAL_VALUE)
643
658
  .Attr(
644
659
  "scale",
645
- "Second, multiply by this.<br>Can be length of features in an [N,F] tensor or length 1, in which case it applies to all features, regardless of dimension count.<br>Must be same length as 'offset'",
660
+ "Second, multiply by this.<br>Can be length of features in an [N,F] tensor or length 1, in which case it "
661
+ "applies to all features, regardless of dimension count.<br>Must be same length as 'offset'",
646
662
  AttributeProto::FLOATS,
647
663
  OPTIONAL_VALUE));
648
664
 
@@ -660,7 +676,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
660
676
  .Output(
661
677
  1,
662
678
  "Z",
663
- "Class scores (one per class per example), if prob_a and prob_b are provided they are probabilities for each class, otherwise they are raw scores.",
679
+ "Class scores (one per class per example), if prob_a and prob_b are provided they are probabilities for "
680
+ "each class, otherwise they are raw scores.",
664
681
  "tensor(float)")
665
682
  .TypeConstraint(
666
683
  "T1",
@@ -669,7 +686,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
669
686
  .TypeConstraint(
670
687
  "T2",
671
688
  {"tensor(string)", "tensor(int64)"},
672
- "The output type will be a tensor of strings or integers, depending on which of the classlabels_* attributes is used. Its size will match the bactch size of the input.")
689
+ "The output type will be a tensor of strings or integers, depending on which of the classlabels_* "
690
+ "attributes is used. Its size will match the bactch size of the input.")
673
691
  .Attr(
674
692
  "kernel_type",
675
693
  "The kernel type, one of 'LINEAR,' 'POLY,' 'RBF,' 'SIGMOID'.",
@@ -686,23 +704,27 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
686
704
  .Attr("prob_a", "First set of probability coefficients.", AttributeProto::FLOATS, OPTIONAL_VALUE)
687
705
  .Attr(
688
706
  "prob_b",
689
- "Second set of probability coefficients. This array must be same size as prob_a.<br>If these are provided then output Z are probability estimates, otherwise they are raw scores.",
707
+ "Second set of probability coefficients. This array must be same size as prob_a.<br>If these are provided "
708
+ "then output Z are probability estimates, otherwise they are raw scores.",
690
709
  AttributeProto::FLOATS,
691
710
  OPTIONAL_VALUE)
692
711
  .Attr("rho", "", AttributeProto::FLOATS, OPTIONAL_VALUE)
693
712
  .Attr(
694
713
  "post_transform",
695
- "Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT'",
714
+ "Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
715
+ "or 'PROBIT'",
696
716
  AttributeProto::STRING,
697
717
  std::string("NONE"))
698
718
  .Attr(
699
719
  "classlabels_strings",
700
- "Class labels if using string labels.<br>One and only one of the 'classlabels_*' attributes must be defined.",
720
+ "Class labels if using string labels.<br>One and only one of the 'classlabels_*' attributes must be "
721
+ "defined.",
701
722
  AttributeProto::STRINGS,
702
723
  OPTIONAL_VALUE)
703
724
  .Attr(
704
725
  "classlabels_ints",
705
- "Class labels if using integer labels.<br>One and only one of the 'classlabels_*' attributes must be defined.",
726
+ "Class labels if using integer labels.<br>One and only one of the 'classlabels_*' attributes must be "
727
+ "defined.",
706
728
  AttributeProto::INTS,
707
729
  OPTIONAL_VALUE)
708
730
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
@@ -752,12 +774,16 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
752
774
  .Attr("n_supports", "The number of support vectors.", AttributeProto::INT, static_cast<int64_t>(0))
753
775
  .Attr(
754
776
  "post_transform",
755
- "Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT.'",
777
+ "Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
778
+ "or 'PROBIT.'",
756
779
  AttributeProto::STRING,
757
780
  std::string("NONE"))
758
781
  .Attr("rho", "", AttributeProto::FLOATS, OPTIONAL_VALUE));
759
782
 
760
- static const char* TreeEnsembleClassifier_ver3_doc = R"DOC(
783
+ static const char* TreeEnsembleClassifier_ver5_doc = R"DOC(
784
+ This operator is DEPRECATED. Please use TreeEnsemble with provides similar functionality.
785
+ In order to determine the top class, the ArgMax node can be applied to the output of TreeEnsemble.
786
+ To encode class labels, use a LabelEncoder operator.
761
787
  Tree Ensemble classifier. Returns the top class for each of N inputs.<br>
762
788
  The attributes named 'nodes_X' form a sequence of tuples, associated by
763
789
  index into the sequences, which must all be of equal length. These tuples
@@ -773,9 +799,10 @@ static const char* TreeEnsembleClassifier_ver3_doc = R"DOC(
773
799
 
774
800
  ONNX_ML_OPERATOR_SET_SCHEMA(
775
801
  TreeEnsembleClassifier,
776
- 3,
802
+ 5,
777
803
  OpSchema()
778
- .SetDoc(TreeEnsembleClassifier_ver3_doc)
804
+ .Deprecate()
805
+ .SetDoc(TreeEnsembleClassifier_ver5_doc)
779
806
  .Input(0, "X", "Input of shape [N,F]", "T1")
780
807
  .Output(0, "Y", "N, Top class for each point", "T2")
781
808
  .Output(1, "Z", "The class score for each class, for each point, a tensor of shape [N,E].", "tensor(float)")
@@ -786,7 +813,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
786
813
  .TypeConstraint(
787
814
  "T2",
788
815
  {"tensor(string)", "tensor(int64)"},
789
- "The output type will be a tensor of strings or integers, depending on which of the classlabels_* attributes is used.")
816
+ "The output type will be a tensor of strings or integers, depending on which of the classlabels_* "
817
+ "attributes is used.")
790
818
  .Attr("nodes_treeids", "Tree id for each node.", AttributeProto::INTS, OPTIONAL_VALUE)
791
819
  .Attr(
792
820
  "nodes_nodeids",
@@ -816,14 +844,17 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
816
844
  OPTIONAL_VALUE)
817
845
  .Attr(
818
846
  "nodes_modes",
819
- "The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
847
+ "The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf "
848
+ "node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
820
849
  AttributeProto::STRINGS,
821
850
  OPTIONAL_VALUE)
822
851
  .Attr("nodes_truenodeids", "Child node if expression is true.", AttributeProto::INTS, OPTIONAL_VALUE)
823
852
  .Attr("nodes_falsenodeids", "Child node if expression is false.", AttributeProto::INTS, OPTIONAL_VALUE)
824
853
  .Attr(
825
854
  "nodes_missing_value_tracks_true",
826
- "For each node, define what to do in the presence of a missing value: if a value is missing (NaN), use the 'true' or 'false' branch based on the value in this array.<br>This attribute may be left undefined, and the default value is false (0) for all nodes.",
855
+ "For each node, define what to do in the presence of a missing value: if a value is missing (NaN), use the "
856
+ "'true' or 'false' branch based on the value in this array.<br>This attribute may be left undefined, and "
857
+ "the default value is false (0) for all nodes.",
827
858
  AttributeProto::INTS,
828
859
  OPTIONAL_VALUE)
829
860
  .Attr("class_treeids", "The id of the tree that this node is in.", AttributeProto::INTS, OPTIONAL_VALUE)
@@ -837,85 +868,38 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
837
868
  OPTIONAL_VALUE)
838
869
  .Attr(
839
870
  "classlabels_strings",
840
- "Class labels if using string labels.<br>One and only one of the 'classlabels_*' attributes must be defined.",
871
+ "Class labels if using string labels.<br>One and only one of the 'classlabels_*' attributes must be "
872
+ "defined.",
841
873
  AttributeProto::STRINGS,
842
874
  OPTIONAL_VALUE)
843
875
  .Attr(
844
876
  "classlabels_int64s",
845
- "Class labels if using integer labels.<br>One and only one of the 'classlabels_*' attributes must be defined.",
877
+ "Class labels if using integer labels.<br>One and only one of the 'classlabels_*' attributes must be "
878
+ "defined.",
846
879
  AttributeProto::INTS,
847
880
  OPTIONAL_VALUE)
848
881
  .Attr(
849
882
  "post_transform",
850
- "Indicates the transform to apply to the score. <br> One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT.'",
883
+ "Indicates the transform to apply to the score. <br> One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
884
+ "or 'PROBIT.'",
851
885
  AttributeProto::STRING,
852
886
  std::string("NONE"))
853
887
  .Attr(
854
888
  "base_values",
855
- "Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
889
+ "Base values for classification, added to final class score; the size must be the same as the classes or "
890
+ "can be left unassigned (assumed 0)",
856
891
  AttributeProto::FLOATS,
857
892
  OPTIONAL_VALUE)
858
893
  .Attr(
859
894
  "base_values_as_tensor",
860
- "Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
895
+ "Base values for classification, added to final class score; the size must be the same as the classes or "
896
+ "can be left unassigned (assumed 0)",
861
897
  AttributeProto::TENSOR,
862
- OPTIONAL_VALUE)
863
- .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
864
- auto* nodes_values = ctx.getAttribute("nodes_values");
865
- auto* nodes_values_as_tensor = ctx.getAttribute("nodes_values_as_tensor");
866
- auto* nodes_hitrates = ctx.getAttribute("nodes_hitrates");
867
- auto* nodes_hitrates_as_tensor = ctx.getAttribute("nodes_hitrates_as_tensor");
868
- auto* class_weights = ctx.getAttribute("class_weights");
869
- auto* class_weights_as_tensor = ctx.getAttribute("class_weights_as_tensor");
870
- auto* base_values = ctx.getAttribute("base_values");
871
- auto* base_values_as_tensor = ctx.getAttribute("base_values_as_tensor");
872
-
873
- if (nullptr != nodes_values && nullptr != nodes_values_as_tensor) {
874
- fail_shape_inference(
875
- "Only one of the attributes 'nodes_values', 'nodes_values_as_tensor' should be specified.");
876
- }
877
- if (nullptr != nodes_hitrates && nullptr != nodes_hitrates_as_tensor) {
878
- fail_shape_inference(
879
- "Only one of the attributes 'nodes_hitrates', 'nodes_hitrates_as_tensor' should be specified.");
880
- }
881
- if (nullptr != class_weights && nullptr != class_weights_as_tensor) {
882
- fail_shape_inference(
883
- "Only one of the attributes 'class_weights', 'class_weights_as_tensor' should be specified.");
884
- }
885
- if (nullptr != base_values && nullptr != base_values_as_tensor) {
886
- fail_shape_inference(
887
- "Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified.");
888
- }
889
-
890
- std::vector<std::string> classlabels_strings;
891
- auto result = getRepeatedAttribute(ctx, "classlabels_strings", classlabels_strings);
892
- bool using_strings = (result && !classlabels_strings.empty());
893
- if (using_strings) {
894
- updateOutputElemType(ctx, 0, TensorProto::STRING);
895
- } else {
896
- updateOutputElemType(ctx, 0, TensorProto::INT64);
897
- }
898
- updateOutputElemType(ctx, 1, TensorProto::FLOAT);
899
-
900
- checkInputRank(ctx, 0, 2);
901
- Dim N, E;
902
- unifyInputDim(ctx, 0, 0, N);
903
-
904
- if (using_strings) {
905
- unifyDim(E, classlabels_strings.size());
906
- } else {
907
- std::vector<int64_t> classlabels_int64s;
908
- result = getRepeatedAttribute(ctx, "classlabels_int64s", classlabels_int64s);
909
- if (!result || classlabels_int64s.empty()) {
910
- fail_shape_inference("Non of classlabels_int64s or classlabels_strings is set.");
911
- }
912
- unifyDim(E, classlabels_int64s.size());
913
- }
914
- updateOutputShape(ctx, 0, {N});
915
- updateOutputShape(ctx, 1, {N, E});
916
- }));
898
+ OPTIONAL_VALUE));
917
899
 
918
- static const char* TreeEnsembleRegressor_ver3_doc = R"DOC(
900
+ static const char* TreeEnsembleRegressor_ver5_doc = R"DOC(
901
+ This operator is DEPRECATED. Please use TreeEnsemble instead which provides the same
902
+ functionality.<br>
919
903
  Tree Ensemble regressor. Returns the regressed values for each input in N.<br>
920
904
  All args with nodes_ are fields of a tuple of tree nodes, and
921
905
  it is assumed they are the same length, and an index i will decode the
@@ -932,9 +916,10 @@ static const char* TreeEnsembleRegressor_ver3_doc = R"DOC(
932
916
 
933
917
  ONNX_ML_OPERATOR_SET_SCHEMA(
934
918
  TreeEnsembleRegressor,
935
- 3,
919
+ 5,
936
920
  OpSchema()
937
- .SetDoc(TreeEnsembleRegressor_ver3_doc)
921
+ .Deprecate()
922
+ .SetDoc(TreeEnsembleRegressor_ver5_doc)
938
923
  .Input(0, "X", "Input of shape [N,F]", "T")
939
924
  .Output(0, "Y", "N classes", "tensor(float)")
940
925
  .TypeConstraint(
@@ -970,14 +955,17 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
970
955
  OPTIONAL_VALUE)
971
956
  .Attr(
972
957
  "nodes_modes",
973
- "The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
958
+ "The node kind, that is, the comparison to make at the node. There is no comparison to make at a leaf "
959
+ "node.<br>One of 'BRANCH_LEQ', 'BRANCH_LT', 'BRANCH_GTE', 'BRANCH_GT', 'BRANCH_EQ', 'BRANCH_NEQ', 'LEAF'",
974
960
  AttributeProto::STRINGS,
975
961
  OPTIONAL_VALUE)
976
962
  .Attr("nodes_truenodeids", "Child node if expression is true", AttributeProto::INTS, OPTIONAL_VALUE)
977
963
  .Attr("nodes_falsenodeids", "Child node if expression is false", AttributeProto::INTS, OPTIONAL_VALUE)
978
964
  .Attr(
979
965
  "nodes_missing_value_tracks_true",
980
- "For each node, define what to do in the presence of a NaN: use the 'true' (if the attribute value is 1) or 'false' (if the attribute value is 0) branch based on the value in this array.<br>This attribute may be left undefined and the default value is false (0) for all nodes.",
966
+ "For each node, define what to do in the presence of a NaN: use the 'true' (if the attribute value is 1) "
967
+ "or 'false' (if the attribute value is 0) branch based on the value in this array.<br>This attribute may "
968
+ "be left undefined and the default value is false (0) for all nodes.",
981
969
  AttributeProto::INTS,
982
970
  OPTIONAL_VALUE)
983
971
  .Attr("target_treeids", "The id of the tree that each node is in.", AttributeProto::INTS, OPTIONAL_VALUE)
@@ -988,7 +976,8 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
988
976
  .Attr("n_targets", "The total number of targets.", AttributeProto::INT, OPTIONAL_VALUE)
989
977
  .Attr(
990
978
  "post_transform",
991
- "Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' or 'PROBIT'",
979
+ "Indicates the transform to apply to the score. <br>One of 'NONE,' 'SOFTMAX,' 'LOGISTIC,' 'SOFTMAX_ZERO,' "
980
+ "or 'PROBIT'",
992
981
  AttributeProto::STRING,
993
982
  std::string("NONE"))
994
983
  .Attr(
@@ -998,48 +987,215 @@ ONNX_ML_OPERATOR_SET_SCHEMA(
998
987
  std::string("SUM"))
999
988
  .Attr(
1000
989
  "base_values",
1001
- "Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
990
+ "Base values for regression, added to final prediction after applying aggregate_function; the size must be "
991
+ "the same as the classes or can be left unassigned (assumed 0)",
1002
992
  AttributeProto::FLOATS,
1003
993
  OPTIONAL_VALUE)
1004
994
  .Attr(
1005
995
  "base_values_as_tensor",
1006
- "Base values for classification, added to final class score; the size must be the same as the classes or can be left unassigned (assumed 0)",
996
+ "Base values for regression, added to final prediction after applying aggregate_function; the size must be "
997
+ "the same as the classes or can be left unassigned (assumed 0)",
998
+ AttributeProto::TENSOR,
999
+ OPTIONAL_VALUE));
1000
+
1001
+ static const char* TreeEnsemble_ver5_doc = R"DOC(
1002
+ Tree Ensemble operator. Returns the regressed values for each input in a batch.
1003
+ Inputs have dimensions `[N, F]` where `N` is the input batch size and `F` is the number of input features.
1004
+ Outputs have dimensions `[N, num_targets]` where `N` is the batch size and `num_targets` is the number of targets, which is a configurable attribute.
1005
+
1006
+ The encoding of this attribute is split along interior nodes and the leaves of the trees. Notably, attributes with the prefix `nodes_*` are associated with interior nodes, and attributes with the prefix `leaf_*` are associated with leaves.
1007
+ The attributes `nodes_*` must all have the same length and encode a sequence of tuples, as defined by taking all the `nodes_*` fields at a given position.
1008
+
1009
+ All fields prefixed with `leaf_*` represent tree leaves, and similarly define tuples of leaves and must have identical length.
1010
+
1011
+ This operator can be used to implement both the previous `TreeEnsembleRegressor` and `TreeEnsembleClassifier` nodes.
1012
+ The `TreeEnsembleRegressor` node maps directly to this node and requires changing how the nodes are represented.
1013
+ The `TreeEnsembleClassifier` node can be implemented by adding a `ArgMax` node after this node to determine the top class.
1014
+ To encode class labels, a `LabelEncoder` or `GatherND` operator may be used.
1015
+ )DOC";
1016
+
1017
+ ONNX_ML_OPERATOR_SET_SCHEMA(
1018
+ TreeEnsemble,
1019
+ 5,
1020
+ OpSchema()
1021
+ .SetDoc(TreeEnsemble_ver5_doc)
1022
+ .Input(0, "X", "Input of shape [Batch Size, Number of Features]", "T")
1023
+ .Output(0, "Y", "Output of shape [Batch Size, Number of targets]", "T")
1024
+ .TypeConstraint(
1025
+ "T",
1026
+ {"tensor(float)", "tensor(double)", "tensor(float16)"},
1027
+ "The input type must be a tensor of a numeric type.")
1028
+ .Attr("nodes_featureids", "Feature id for each node.", AttributeProto::INTS, true)
1029
+ .Attr(
1030
+ "nodes_splits",
1031
+ "Thresholds to do the splitting on for each node with mode that is not 'BRANCH_MEMBER'.",
1032
+ AttributeProto::TENSOR,
1033
+ true)
1034
+ .Attr(
1035
+ "nodes_hitrates",
1036
+ "Popularity of each node, used for performance and may be omitted.",
1037
+ AttributeProto::TENSOR,
1038
+ OPTIONAL_VALUE)
1039
+ .Attr(
1040
+ "nodes_modes",
1041
+ "The comparison operation performed by the node. This is encoded as an enumeration of 0 ('BRANCH_LEQ'), 1 "
1042
+ "('BRANCH_LT'), 2 ('BRANCH_GTE'), 3 ('BRANCH_GT'), 4 ('BRANCH_EQ'), 5 ('BRANCH_NEQ'), and 6 "
1043
+ "('BRANCH_MEMBER'). Note this is a tensor of type uint8.",
1044
+ AttributeProto::TENSOR,
1045
+ true)
1046
+ .Attr(
1047
+ "nodes_truenodeids",
1048
+ "If `nodes_trueleafs` is false at an entry, this represents the position of the true branch node. This "
1049
+ "position can be used to index into a `nodes_*` entry. If `nodes_trueleafs` is false, it is an index into "
1050
+ "the leaf_* attributes.",
1051
+ AttributeProto::INTS,
1052
+ true)
1053
+ .Attr(
1054
+ "nodes_falsenodeids",
1055
+ "If `nodes_falseleafs` is false at an entry, this represents the position of the false branch node. This "
1056
+ "position can be used to index into a `nodes_*` entry. If `nodes_falseleafs` is false, it is an index into "
1057
+ "the leaf_* attributes.",
1058
+ AttributeProto::INTS,
1059
+ true)
1060
+ .Attr(
1061
+ "nodes_trueleafs",
1062
+ "1 if true branch is leaf for each node and 0 an interior node. To represent a tree that is a leaf (only "
1063
+ "has one node), one can do so by having a single `nodes_*` entry with true and false branches referencing "
1064
+ "the same `leaf_*` entry",
1065
+ AttributeProto::INTS,
1066
+ true)
1067
+ .Attr(
1068
+ "nodes_falseleafs",
1069
+ "1 if false branch is leaf for each node and 0 if an interior node. To represent a tree that is a leaf "
1070
+ "(only has one node), one can do so by having a single `nodes_*` entry with true and false branches "
1071
+ "referencing the same `leaf_*` entry",
1072
+ AttributeProto::INTS,
1073
+ true)
1074
+ .Attr(
1075
+ "nodes_missing_value_tracks_true",
1076
+ "For each node, define whether to follow the true branch (if attribute value is 1) or false branch (if "
1077
+ "attribute value is 0) in the presence of a NaN input feature. This attribute may be left undefined and "
1078
+ "the default value is false (0) for all nodes.",
1079
+ AttributeProto::INTS,
1080
+ OPTIONAL_VALUE)
1081
+ .Attr(
1082
+ "tree_roots",
1083
+ "Index into `nodes_*` for the root of each tree. The tree structure is derived from the branching of each "
1084
+ "node.",
1085
+ AttributeProto::INTS,
1086
+ true)
1087
+ .Attr(
1088
+ "membership_values",
1089
+ "Members to test membership of for each set membership node. List all of the members to test again in the "
1090
+ "order that the 'BRANCH_MEMBER' mode appears in `node_modes`, delimited by `NaN`s. Will have the same "
1091
+ "number "
1092
+ "of sets of values as nodes with mode 'BRANCH_MEMBER'. This may be omitted if the node doesn't contain any "
1093
+ "'BRANCH_MEMBER' nodes.",
1007
1094
  AttributeProto::TENSOR,
1008
1095
  OPTIONAL_VALUE)
1096
+ .Attr(
1097
+ "leaf_targetids",
1098
+ "The index of the target that this leaf contributes to (this must be in range `[0, n_targets)`).",
1099
+ AttributeProto::INTS,
1100
+ true)
1101
+ .Attr("leaf_weights", "The weight for each leaf.", AttributeProto::TENSOR, true)
1102
+ .Attr("n_targets", "The total number of targets.", AttributeProto::INT, OPTIONAL_VALUE)
1103
+ .Attr(
1104
+ "post_transform",
1105
+ "Indicates the transform to apply to the score. <br>One of 'NONE' (0), 'SOFTMAX' (1), 'LOGISTIC' (2), "
1106
+ "'SOFTMAX_ZERO' (3) or 'PROBIT' (4), defaults to 'NONE' (0)",
1107
+ AttributeProto::INT,
1108
+ static_cast<int64_t>(0))
1109
+ .Attr(
1110
+ "aggregate_function",
1111
+ "Defines how to aggregate leaf values within a target. <br>One of 'AVERAGE' (0) 'SUM' (1) 'MIN' (2) 'MAX "
1112
+ "(3) defaults to 'SUM' (1)",
1113
+ AttributeProto::INT,
1114
+ static_cast<int64_t>(1))
1009
1115
  .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
1010
- auto* nodes_values = ctx.getAttribute("nodes_values");
1011
- auto* nodes_values_as_tensor = ctx.getAttribute("nodes_values_as_tensor");
1012
- auto* nodes_hitrates = ctx.getAttribute("nodes_hitrates");
1013
- auto* nodes_hitrates_as_tensor = ctx.getAttribute("nodes_hitrates_as_tensor");
1014
- auto* target_weights = ctx.getAttribute("target_weights");
1015
- auto* target_weights_as_tensor = ctx.getAttribute("target_weights_as_tensor");
1016
- auto* base_values = ctx.getAttribute("base_values");
1017
- auto* base_values_as_tensor = ctx.getAttribute("base_values_as_tensor");
1018
-
1019
- if (nullptr != nodes_values && nullptr != nodes_values_as_tensor) {
1020
- fail_shape_inference(
1021
- "Only one of the attributes 'nodes_values', 'nodes_values_as_tensor' should be specified.");
1116
+ checkInputRank(ctx, 0, 2);
1117
+ auto* nodes_splits = ctx.getAttribute("nodes_splits");
1118
+ if (nullptr == nodes_splits) {
1119
+ fail_shape_inference("Attribute 'nodes_splits' is required.");
1022
1120
  }
1023
- if (nullptr != nodes_hitrates && nullptr != nodes_hitrates_as_tensor) {
1121
+ if (nodes_splits->t().dims_size() != 1) {
1122
+ fail_shape_inference("Attribute 'nodes_splits' must be 1D.");
1123
+ }
1124
+ auto input_type = ctx.getInputType(0)->tensor_type().elem_type();
1125
+ // Check that input type is same as split type
1126
+ if (input_type != nodes_splits->t().data_type()) {
1024
1127
  fail_shape_inference(
1025
- "Only one of the attributes 'nodes_hitrates', 'nodes_hitrates_as_tensor' should be specified.");
1128
+ "Attribute 'nodes_splits' must have same type as input. Input type is ",
1129
+ input_type,
1130
+ " and attribute type is ",
1131
+ nodes_splits->t().data_type());
1026
1132
  }
1027
- if (nullptr != target_weights && nullptr != target_weights_as_tensor) {
1133
+
1134
+ // Expected nodes_* length
1135
+ auto expected_length = nodes_splits->t().dims(0);
1136
+ // Validate all nodes_* attributes that are set have the same length and are 1D.
1137
+ AssertAttributeProtoTypeAndLength(
1138
+ ctx.getAttribute("nodes_featureids"), expected_length, TensorProto_DataType_INT64, true);
1139
+ AssertAttributeProtoTypeAndLength(
1140
+ ctx.getAttribute("nodes_hitrates"), expected_length, TensorProto_DataType_FLOAT, false);
1141
+ AssertAttributeProtoTypeAndLength(
1142
+ ctx.getAttribute("nodes_modes"), expected_length, TensorProto_DataType_UINT8, true);
1143
+ AssertAttributeProtoTypeAndLength(
1144
+ ctx.getAttribute("nodes_truenodeids"), expected_length, TensorProto_DataType_INT64, true);
1145
+ AssertAttributeProtoTypeAndLength(
1146
+ ctx.getAttribute("nodes_falsenodeids"), expected_length, TensorProto_DataType_INT64, true);
1147
+ AssertAttributeProtoTypeAndLength(
1148
+ ctx.getAttribute("nodes_trueleafs"), expected_length, TensorProto_DataType_INT64, true);
1149
+ AssertAttributeProtoTypeAndLength(
1150
+ ctx.getAttribute("nodes_falseleafs"), expected_length, TensorProto_DataType_INT64, true);
1151
+ AssertAttributeProtoTypeAndLength(
1152
+ ctx.getAttribute("nodes_missing_value_tracks_true"), expected_length, TensorProto_DataType_INT64, false);
1153
+
1154
+ // The set membership values and the splits must have the same type as the input.
1155
+ auto* membership_values = ctx.getAttribute("membership_values");
1156
+ if (nullptr != membership_values && membership_values->t().data_type() != input_type) {
1028
1157
  fail_shape_inference(
1029
- "Only one of the attributes 'target_weights', 'target_weights_as_tensor' should be specified.");
1158
+ "Attribute 'membership_values' must have same type as input. Input type is ",
1159
+ input_type,
1160
+ " and attribute type is ",
1161
+ membership_values->t().data_type());
1030
1162
  }
1031
- if (nullptr != base_values && nullptr != base_values_as_tensor) {
1163
+ AssertAttributeProtoTypeAndLength(
1164
+ ctx.getAttribute("nodes_splits"), expected_length, static_cast<TensorProto_DataType>(input_type), true);
1165
+
1166
+ // Validate all leaf_* attributes that are set have the same length and are 1D.
1167
+ auto* leaf_targetids = ctx.getAttribute("leaf_targetids");
1168
+ auto* leaf_weights = ctx.getAttribute("leaf_weights");
1169
+ if (nullptr != leaf_targetids && nullptr != leaf_weights) {
1170
+ if (leaf_targetids->ints_size() != leaf_weights->t().dims(0)) {
1171
+ fail_shape_inference(
1172
+ "Attribute 'leaf_targetids' must have same length as attribute 'leaf_weights'. 'leaf_targetids' "
1173
+ "length is ",
1174
+ leaf_targetids->ints_size(),
1175
+ " and 'leaf_weights' length is ",
1176
+ leaf_weights->t().dims(0));
1177
+ }
1178
+ } else {
1179
+ fail_shape_inference("Attributes 'leaf_targetids' and 'leaf_weights' must both be set.");
1180
+ }
1181
+
1182
+ // Validate weights have same type as input.
1183
+ if (leaf_weights->t().data_type() != input_type) {
1032
1184
  fail_shape_inference(
1033
- "Only one of the attributes 'base_values', 'base_values_as_tensor' should be specified.");
1185
+ "Attribute 'leaf_weights' must have same type as input. Input type is ",
1186
+ input_type,
1187
+ " and attribute type is ",
1188
+ leaf_weights->t().data_type());
1034
1189
  }
1035
1190
 
1036
1191
  checkInputRank(ctx, 0, 2);
1192
+
1037
1193
  Dim N, E;
1038
1194
  unifyInputDim(ctx, 0, 0, N);
1039
1195
  if (nullptr != ctx.getAttribute("n_targets")) {
1040
1196
  unifyDim(E, ctx.getAttribute("n_targets")->i());
1041
1197
  }
1042
- updateOutputElemType(ctx, 0, TensorProto::FLOAT);
1198
+ updateOutputElemType(ctx, 0, input_type);
1043
1199
  updateOutputShape(ctx, 0, {N, E});
1044
1200
  }));
1045
1201