onnx 1.19.0__cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl → 1.19.1rc1__cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (202)
  1. onnx/__init__.py +98 -0
  2. onnx/backend/test/case/node/__init__.py +20 -3
  3. onnx/backend/test/case/node/attention.py +62 -0
  4. onnx/backend/test/case/node/rotaryembedding.py +6 -6
  5. onnx/backend/test/data/node/test_attention_3d/model.onnx +0 -0
  6. onnx/backend/test/data/node/test_attention_3d_attn_mask/model.onnx +0 -0
  7. onnx/backend/test/data/node/test_attention_3d_attn_mask_expanded/model.onnx +0 -0
  8. onnx/backend/test/data/node/test_attention_3d_causal/model.onnx +0 -0
  9. onnx/backend/test/data/node/test_attention_3d_causal_expanded/model.onnx +0 -0
  10. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes/model.onnx +0 -0
  11. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask/model.onnx +0 -0
  12. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
  13. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal/model.onnx +0 -0
  14. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
  15. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_expanded/model.onnx +0 -0
  16. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled/model.onnx +0 -0
  17. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
  18. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap/model.onnx +0 -0
  19. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
  20. onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present/model.onnx +0 -0
  21. onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
  22. onnx/backend/test/data/node/test_attention_3d_expanded/model.onnx +0 -0
  23. onnx/backend/test/data/node/test_attention_3d_gqa/model.onnx +0 -0
  24. onnx/backend/test/data/node/test_attention_3d_gqa/test_data_set_0/output_0.pb +0 -0
  25. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/model.onnx +0 -0
  26. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
  27. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/model.onnx +0 -0
  28. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
  29. onnx/backend/test/data/node/test_attention_3d_gqa_causal/model.onnx +0 -0
  30. onnx/backend/test/data/node/test_attention_3d_gqa_causal/test_data_set_0/output_0.pb +0 -0
  31. onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/model.onnx +0 -0
  32. onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
  33. onnx/backend/test/data/node/test_attention_3d_gqa_expanded/model.onnx +0 -0
  34. onnx/backend/test/data/node/test_attention_3d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
  35. onnx/backend/test/data/node/test_attention_3d_gqa_scaled/model.onnx +0 -0
  36. onnx/backend/test/data/node/test_attention_3d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
  37. onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/model.onnx +0 -0
  38. onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
  39. onnx/backend/test/data/node/test_attention_3d_gqa_softcap/model.onnx +0 -0
  40. onnx/backend/test/data/node/test_attention_3d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
  41. onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/model.onnx +0 -0
  42. onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
  43. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/model.onnx +0 -0
  44. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
  45. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/model.onnx +0 -0
  46. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
  47. onnx/backend/test/data/node/test_attention_3d_scaled/model.onnx +0 -0
  48. onnx/backend/test/data/node/test_attention_3d_scaled_expanded/model.onnx +0 -0
  49. onnx/backend/test/data/node/test_attention_3d_softcap/model.onnx +0 -0
  50. onnx/backend/test/data/node/test_attention_3d_softcap_expanded/model.onnx +0 -0
  51. onnx/backend/test/data/node/test_attention_3d_transpose_verification/model.onnx +0 -0
  52. onnx/backend/test/data/node/test_attention_3d_transpose_verification_expanded/model.onnx +0 -0
  53. onnx/backend/test/data/node/test_attention_3d_with_past_and_present/model.onnx +0 -0
  54. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_expanded/model.onnx +0 -0
  55. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul/model.onnx +0 -0
  56. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
  57. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
  58. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
  59. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap/model.onnx +0 -0
  60. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap_expanded/model.onnx +0 -0
  61. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax/model.onnx +0 -0
  62. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax_expanded/model.onnx +0 -0
  63. onnx/backend/test/data/node/test_attention_4d/model.onnx +0 -0
  64. onnx/backend/test/data/node/test_attention_4d_attn_mask/model.onnx +0 -0
  65. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d/model.onnx +0 -0
  66. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal/model.onnx +0 -0
  67. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal_expanded/model.onnx +0 -0
  68. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_expanded/model.onnx +0 -0
  69. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d/model.onnx +0 -0
  70. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal/model.onnx +0 -0
  71. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal_expanded/model.onnx +0 -0
  72. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_expanded/model.onnx +0 -0
  73. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool/model.onnx +0 -0
  74. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d/model.onnx +0 -0
  75. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d_expanded/model.onnx +0 -0
  76. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_expanded/model.onnx +0 -0
  77. onnx/backend/test/data/node/test_attention_4d_attn_mask_expanded/model.onnx +0 -0
  78. onnx/backend/test/data/node/test_attention_4d_causal/model.onnx +0 -0
  79. onnx/backend/test/data/node/test_attention_4d_causal_expanded/model.onnx +0 -0
  80. onnx/backend/test/data/node/test_attention_4d_diff_heads_mask4d_padded_kv_expanded/model.onnx +0 -0
  81. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes/model.onnx +0 -0
  82. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask/model.onnx +0 -0
  83. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
  84. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal/model.onnx +0 -0
  85. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
  86. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_expanded/model.onnx +0 -0
  87. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled/model.onnx +0 -0
  88. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
  89. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap/model.onnx +0 -0
  90. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
  91. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present/model.onnx +0 -0
  92. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
  93. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d/model.onnx +0 -0
  94. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d_expanded/model.onnx +0 -0
  95. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d/model.onnx +0 -0
  96. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d_expanded/model.onnx +0 -0
  97. onnx/backend/test/data/node/test_attention_4d_expanded/model.onnx +0 -0
  98. onnx/backend/test/data/node/test_attention_4d_fp16/model.onnx +0 -0
  99. onnx/backend/test/data/node/test_attention_4d_fp16_expanded/model.onnx +0 -0
  100. onnx/backend/test/data/node/test_attention_4d_gqa/model.onnx +0 -0
  101. onnx/backend/test/data/node/test_attention_4d_gqa/test_data_set_0/output_0.pb +0 -0
  102. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/model.onnx +0 -0
  103. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
  104. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/model.onnx +0 -0
  105. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
  106. onnx/backend/test/data/node/test_attention_4d_gqa_causal/model.onnx +0 -0
  107. onnx/backend/test/data/node/test_attention_4d_gqa_causal/test_data_set_0/output_0.pb +0 -0
  108. onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/model.onnx +0 -0
  109. onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
  110. onnx/backend/test/data/node/test_attention_4d_gqa_expanded/model.onnx +0 -0
  111. onnx/backend/test/data/node/test_attention_4d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
  112. onnx/backend/test/data/node/test_attention_4d_gqa_scaled/model.onnx +0 -0
  113. onnx/backend/test/data/node/test_attention_4d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
  114. onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/model.onnx +0 -0
  115. onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
  116. onnx/backend/test/data/node/test_attention_4d_gqa_softcap/model.onnx +0 -0
  117. onnx/backend/test/data/node/test_attention_4d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
  118. onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/model.onnx +0 -0
  119. onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
  120. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/model.onnx +0 -0
  121. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
  122. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/model.onnx +0 -0
  123. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
  124. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/model.onnx +0 -0
  125. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/test_data_set_0/output_0.pb +0 -0
  126. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/model.onnx +0 -0
  127. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/test_data_set_0/output_0.pb +0 -0
  128. onnx/backend/test/data/node/test_attention_4d_scaled/model.onnx +0 -0
  129. onnx/backend/test/data/node/test_attention_4d_scaled_expanded/model.onnx +0 -0
  130. onnx/backend/test/data/node/test_attention_4d_softcap/model.onnx +0 -0
  131. onnx/backend/test/data/node/test_attention_4d_softcap_expanded/model.onnx +0 -0
  132. onnx/backend/test/data/node/test_attention_4d_with_past_and_present/model.onnx +0 -0
  133. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_expanded/model.onnx +0 -0
  134. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul/model.onnx +0 -0
  135. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
  136. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask/model.onnx +0 -0
  137. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/model.onnx +0 -0
  138. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_0.pb +0 -0
  139. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_3.pb +0 -0
  140. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/model.onnx +0 -0
  141. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
  142. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
  143. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_expanded/model.onnx +0 -0
  144. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask/model.onnx +0 -0
  145. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/model.onnx +0 -0
  146. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_0.pb +0 -0
  147. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_3.pb +0 -0
  148. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/model.onnx +0 -0
  149. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
  150. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
  151. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_expanded/model.onnx +0 -0
  152. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
  153. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
  154. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul/model.onnx +0 -0
  155. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias/model.onnx +0 -0
  156. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias_expanded/model.onnx +0 -0
  157. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_expanded/model.onnx +0 -0
  158. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap/model.onnx +0 -0
  159. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap_expanded/model.onnx +0 -0
  160. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax/model.onnx +0 -0
  161. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax_expanded/model.onnx +0 -0
  162. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/model.onnx +0 -0
  163. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_1.pb +1 -1
  164. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_2.pb +1 -1
  165. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/output_0.pb +0 -0
  166. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/model.onnx +0 -0
  167. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_1.pb +1 -1
  168. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_2.pb +1 -1
  169. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  170. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/model.onnx +0 -0
  171. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_1.pb +0 -0
  172. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_2.pb +0 -0
  173. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/output_0.pb +0 -0
  174. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/model.onnx +0 -0
  175. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
  176. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
  177. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  178. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/model.onnx +0 -0
  179. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_1.pb +0 -0
  180. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_2.pb +0 -0
  181. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/output_0.pb +0 -0
  182. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx +0 -0
  183. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
  184. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
  185. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  186. onnx/defs/nn/defs.cc +70 -228
  187. onnx/defs/nn/old.cc +31 -201
  188. onnx/defs/nn/utils.cc +222 -0
  189. onnx/defs/nn/utils.h +25 -0
  190. onnx/onnx_cpp2py_export.cpython-39-aarch64-linux-gnu.so +0 -0
  191. onnx/reference/ops/op_attention.py +33 -14
  192. onnx/reference/ops/op_rotary_embedding.py +21 -19
  193. onnx/test/basic_test.py +84 -0
  194. onnx/test/reference_evaluator_test.py +23 -0
  195. onnx/test/test_backend_reference.py +2 -1
  196. onnx/version.py +1 -1
  197. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/METADATA +2 -2
  198. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/RECORD +202 -200
  199. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/WHEEL +1 -1
  200. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/entry_points.txt +0 -0
  201. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/licenses/LICENSE +0 -0
  202. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/top_level.txt +0 -0
onnx/__init__.py CHANGED
@@ -143,6 +143,9 @@ from onnx import (
     version_converter,
 )
 
+if typing.TYPE_CHECKING:
+    from collections.abc import Sequence
+
 __version__ = onnx.version.version
 
 # Supported model formats that can be loaded from and saved to
@@ -368,3 +371,98 @@ def save_tensor(
 load = load_model
 load_from_string = load_model_from_string
 save = save_model
+
+
+def _model_proto_repr(self: ModelProto) -> str:
+    if self.domain:
+        domain = f", domain='{self.domain}'"
+    else:
+        domain = ""
+    if self.producer_name:
+        producer_name = f", producer_name='{self.producer_name}'"
+    else:
+        producer_name = ""
+    if self.producer_version:
+        producer_version = f", producer_version='{self.producer_version}'"
+    else:
+        producer_version = ""
+    if self.graph:
+        graph = f", graph={self.graph!r}"
+    else:
+        graph = ""
+    if self.functions:
+        functions = f", functions=<{len(self.functions)} functions>"
+    else:
+        functions = ""
+    if self.opset_import:
+        opset_import = f", opset_import={_operator_set_protos_repr(self.opset_import)}"
+    else:
+        opset_import = ""
+    return f"ModelProto(ir_version={self.ir_version}{opset_import}{domain}{producer_name}{producer_version}{graph}{functions})"
+
+
+def _graph_proto_repr(self: GraphProto) -> str:
+    if self.initializer:
+        initializer = f", initializer=<{len(self.initializer)} initializers>"
+    else:
+        initializer = ""
+    if self.node:
+        node = f", node=<{len(self.node)} nodes>"
+    else:
+        node = ""
+    if self.value_info:
+        value_info = f", value_info=<{len(self.value_info)} value_info>"
+    else:
+        value_info = ""
+    if self.input:
+        input = f", input=<{len(self.input)} inputs>"
+    else:
+        input = ""
+    if self.output:
+        output = f", output=<{len(self.output)} outputs>"
+    else:
+        output = ""
+    return f"GraphProto('{self.name}'{input}{output}{initializer}{node}{value_info})"
+
+
+def _function_proto_repr(self: FunctionProto) -> str:
+    if self.domain:
+        domain = f", domain='{self.domain}'"
+    else:
+        domain = ""
+    if self.overload:
+        overload = f", overload='{self.overload}'"
+    else:
+        overload = ""
+    if self.node:
+        node = f", node=<{len(self.node)} nodes>"
+    else:
+        node = ""
+    if self.attribute:
+        attribute = f", attribute={self.attribute}"
+    else:
+        attribute = ""
+    if self.opset_import:
+        opset_import = f", opset_import={_operator_set_protos_repr(self.opset_import)}"
+    else:
+        opset_import = ""
+    if self.input:
+        input = f", input=<{len(self.input)} inputs>"
+    else:
+        input = ""
+    if self.output:
+        output = f", output=<{len(self.output)} outputs>"
+    else:
+        output = ""
+    return f"FunctionProto('{self.name}'{domain}{overload}{opset_import}{input}{output}{attribute}{node})"
+
+
+def _operator_set_protos_repr(protos: Sequence[OperatorSetIdProto]) -> str:
+    opset_imports = {proto.domain: proto.version for proto in protos}
+    return repr(opset_imports)
+
+
+# Override __repr__ for some proto classes to make it more efficient
+ModelProto.__repr__ = _model_proto_repr  # type: ignore[method-assign,assignment]
+GraphProto.__repr__ = _graph_proto_repr  # type: ignore[method-assign,assignment]
+FunctionProto.__repr__ = _function_proto_repr  # type: ignore[method-assign,assignment]
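
To illustrate the effect of the new compact reprs, here is a minimal sketch; the one-node model below is hypothetical, and the printed summary is paraphrased from the format strings above rather than copied from actual output:

    from onnx import TensorProto, helper

    # Hypothetical one-node model, used only to show the compact repr.
    graph = helper.make_graph(
        [helper.make_node("Relu", ["x"], ["y"])],
        "demo",
        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 23)])

    # Instead of dumping the full protobuf text, repr() now summarizes
    # repeated fields, roughly:
    #   ModelProto(ir_version=..., opset_import={'': 23},
    #              graph=GraphProto('demo', input=<1 inputs>, output=<1 outputs>, node=<1 nodes>))
    print(repr(model))

The practical upside is that repr() of a large model no longer serializes every initializer and node, which is what made the default protobuf repr expensive.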
onnx/backend/test/case/node/__init__.py CHANGED
@@ -18,6 +18,7 @@ from onnx.onnx_pb import (
     GraphProto,
     ModelProto,
     NodeProto,
+    OperatorSetIdProto,
     TensorProto,
     TypeProto,
 )
@@ -128,11 +129,25 @@ def function_expand_helper(
 
 
 def function_testcase_helper(
-    node: NodeProto, input_types: list[TypeProto], name: str
+    node: NodeProto,
+    input_types: list[TypeProto],
+    name: str,
+    opset_imports: Sequence[OperatorSetIdProto] | None = None,
 ) -> tuple[list[tuple[list[NodeProto], Any]], int]:
     test_op = node.op_type
     op_prefix = test_op + "_" + name + "_expanded_function_"
-    schema = onnx.defs.get_schema(test_op, domain=node.domain)
+    if opset_imports is None:
+        # No opset in the model. We take the most recent definition.
+        schema = onnx.defs.get_schema(test_op, domain=node.domain)
+    else:
+        # We take the function as defined in the specific version
+        # mentioned in the model.
+        if len(opset_imports) != 1:
+            raise ValueError(
+                f"Only one domain is allowed but {len(opset_imports)} found."
+            )
+        version = opset_imports[0].version
+        schema = onnx.defs.get_schema(test_op, version, domain=node.domain)
 
     # an op schema may have several functions, each for one opset version
     # opset versions include the op's since_version and other opset versions
@@ -327,7 +342,9 @@ def expect(
     (
         expanded_tests,
         since_version,
-    ) = function_testcase_helper(node, merged_types, name)
+    ) = function_testcase_helper(
+        node, merged_types, name, opset_imports=kwargs.get("opset_imports")
+    )
     for expanded_function_nodes, func_opset_import in expanded_tests:
         kwargs["producer_name"] = "backend-test"
onnx/backend/test/case/node/attention.py CHANGED
@@ -27,6 +27,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -44,6 +45,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_fp16",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -61,6 +63,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_gqa",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -78,6 +81,7 @@ class Attention(Base):
            inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_diff_heads_sizes",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -101,6 +105,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_scaled",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -124,6 +129,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_gqa_scaled",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -147,6 +153,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_diff_heads_sizes_scaled",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -169,6 +176,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -191,6 +199,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_gqa_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -218,6 +227,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_diff_heads_sizes_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -245,6 +255,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -272,6 +283,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask_3d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -301,6 +313,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask_3d_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -328,6 +341,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask_4d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -357,6 +371,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask_4d_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -384,6 +399,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask_bool",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -411,6 +427,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_attn_mask_bool_4d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -438,6 +455,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_gqa_attn_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
        )
 
     @staticmethod
@@ -465,6 +483,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_4d_diff_heads_sizes_attn_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -497,6 +516,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_4d_with_past_and_present",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -529,6 +549,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_4d_gqa_with_past_and_present",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -561,6 +582,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_4d_gqa_with_past_and_present_fp16",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -593,6 +615,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_4d_diff_heads_with_past_and_present",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -625,6 +648,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_4d_diff_heads_with_past_and_present_mask3d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -657,6 +681,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_4d_diff_heads_with_past_and_present_mask4d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -679,6 +704,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -701,6 +727,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_gqa_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -728,6 +755,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_4d_diff_heads_sizes_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -749,6 +777,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y, qk_matmul_output],
             name="test_attention_4d_with_qk_matmul",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -778,6 +807,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y, qk_matmul_output],
             name="test_attention_4d_with_qk_matmul_bias",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -809,6 +839,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y, qk_matmul_output],
             name="test_attention_4d_with_qk_matmul_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -838,6 +869,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y, qk_matmul_output],
             name="test_attention_4d_with_qk_matmul_softmax",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -872,6 +904,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_4d_with_past_and_present_qk_matmul_bias",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -906,6 +939,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -940,6 +974,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -976,6 +1011,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1012,6 +1048,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1044,6 +1081,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_4d_with_past_and_present_qk_matmul",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1074,6 +1112,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1104,6 +1143,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_gqa",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1134,6 +1174,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_diff_heads_sizes",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1167,6 +1208,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_scaled",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1200,6 +1242,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_gqa_scaled",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1233,6 +1276,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_diff_heads_sizes_scaled",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1265,6 +1309,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1297,6 +1342,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_gqa_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1329,6 +1375,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_diff_heads_sizes_causal",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1361,6 +1408,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_3d_attn_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1393,6 +1441,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_3d_gqa_attn_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1425,6 +1474,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask],
             outputs=[Y],
             name="test_attention_3d_diff_heads_sizes_attn_mask",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1457,6 +1507,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1489,6 +1540,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_gqa_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1521,6 +1573,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_diff_heads_sizes_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1558,6 +1611,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_3d_with_past_and_present",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1595,6 +1649,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_3d_gqa_with_past_and_present",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1632,6 +1687,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value],
             name="test_attention_3d_diff_heads_with_past_and_present",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1669,6 +1725,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_3d_with_past_and_present_qk_matmul",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1708,6 +1765,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_3d_with_past_and_present_qk_matmul_bias",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1749,6 +1807,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_3d_with_past_and_present_qk_matmul_softcap",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1788,6 +1847,7 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, past_key, past_value],
             outputs=[Y, present_key, present_value, qk_matmul_output],
             name="test_attention_3d_with_past_and_present_qk_matmul_softmax",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1842,6 +1902,7 @@ class Attention(Base):
             inputs=[Q, K, V],
             outputs=[Y],
             name="test_attention_3d_transpose_verification",
+            opset_imports=[onnx.helper.make_opsetid("", 23)],
         )
 
     @staticmethod
@@ -1871,4 +1932,5 @@ class Attention(Base):
             inputs=[Q, K, V, attn_mask, nonpad_kv_seqlen],
             outputs=[Y],
             name="test_attention_4d_diff_heads_mask4d_padded_kv",
+            opset_imports=[onnx.helper.make_opsetid("", 24)],
         )
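
For reference, a short sketch of what these pinned opset_imports do when the test models are built (assuming only the helper API visible in the diff above); note the padded_kv case pins opset 24 because its extra nonpad_kv_seqlen input was added in Attention(24):

    from onnx import helper

    # Pins the default ONNX domain ("") to opset 23, so the generated test
    # model and its *_expanded variant are built from the Attention(23)
    # definition rather than from the newest registered schema.
    opset = helper.make_opsetid("", 23)
    print(opset.domain, opset.version)  # -> '' 23 (empty string = default domain)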
onnx/backend/test/case/node/rotaryembedding.py CHANGED
@@ -106,8 +106,8 @@ class RotaryEmbedding(Base):
 
         input_data = np.random.rand(2, 4, 3, 8).astype(np.float32)
         position_ids_data = np.random.uniform(0, 50, (2, 3)).astype(np.int64)
-        sin_cache_data = np.random.rand(50, 4).astype(np.float32)
-        cos_cache_data = np.random.rand(50, 4).astype(np.float32)
+        sin_cache_data = np.random.rand(50, 2).astype(np.float32)
+        cos_cache_data = np.random.rand(50, 2).astype(np.float32)
 
         expected_output = rotary_embedding(
             input_data,
@@ -136,8 +136,8 @@ class RotaryEmbedding(Base):
 
         input_data = np.random.rand(2, 4, 3, 8).astype(np.float32)
         position_ids_data = np.random.uniform(0, 50, (2, 3)).astype(np.int64)
-        sin_cache_data = np.random.rand(50, 4).astype(np.float32)
-        cos_cache_data = np.random.rand(50, 4).astype(np.float32)
+        sin_cache_data = np.random.rand(50, 2).astype(np.float32)
+        cos_cache_data = np.random.rand(50, 2).astype(np.float32)
 
         expected_output = rotary_embedding(
             input_data,
@@ -213,8 +213,8 @@ class RotaryEmbedding(Base):
         )
 
         input_data = np.random.rand(2, 4, 3, 8).astype(np.float32)
-        sin_cache_data = np.random.rand(2, 3, 4).astype(np.float32)
-        cos_cache_data = np.random.rand(2, 3, 4).astype(np.float32)
+        sin_cache_data = np.random.rand(2, 3, 2).astype(np.float32)
+        cos_cache_data = np.random.rand(2, 3, 2).astype(np.float32)
 
         expected_output = rotary_embedding(
             input_data,
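
Why the caches shrink from width 4 to width 2: these are the rotary_dim test cases, where only part of each head is rotated. On my reading of the RotaryEmbedding spec, the cos/sin caches hold one angle per rotated channel pair, i.e. rotary_embedding_dim // 2 entries rather than head_size // 2. A sketch of that arithmetic (rotary_embedding_dim = 4 is an assumption inferred from the test names, not stated in this diff):

    import numpy as np

    head_size = 8               # last axis of the (2, 4, 3, 8) input
    rotary_embedding_dim = 4    # assumed: only 4 of the 8 channels are rotated

    # One angle per rotated channel pair -> rotary_embedding_dim // 2 = 2,
    # which matches the corrected cache shapes above.
    max_position = 50
    sin_cache = np.random.rand(max_position, rotary_embedding_dim // 2).astype(np.float32)
    cos_cache = np.random.rand(max_position, rotary_embedding_dim // 2).astype(np.float32)
    print(sin_cache.shape)  # (50, 2)

The old (50, 4) caches matched head_size // 2 instead, which is only correct when the whole head is rotated; regenerating these test inputs is what changed the input_1.pb/input_2.pb files listed above.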