onnx 1.19.0__cp313-cp313t-win_amd64.whl → 1.19.1rc1__cp313-cp313t-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (202)
  1. onnx/__init__.py +98 -0
  2. onnx/backend/test/case/node/__init__.py +20 -3
  3. onnx/backend/test/case/node/attention.py +62 -0
  4. onnx/backend/test/case/node/rotaryembedding.py +6 -6
  5. onnx/backend/test/data/node/test_attention_3d/model.onnx +0 -0
  6. onnx/backend/test/data/node/test_attention_3d_attn_mask/model.onnx +0 -0
  7. onnx/backend/test/data/node/test_attention_3d_attn_mask_expanded/model.onnx +0 -0
  8. onnx/backend/test/data/node/test_attention_3d_causal/model.onnx +0 -0
  9. onnx/backend/test/data/node/test_attention_3d_causal_expanded/model.onnx +0 -0
  10. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes/model.onnx +0 -0
  11. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask/model.onnx +0 -0
  12. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
  13. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal/model.onnx +0 -0
  14. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
  15. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_expanded/model.onnx +0 -0
  16. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled/model.onnx +0 -0
  17. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
  18. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap/model.onnx +0 -0
  19. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
  20. onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present/model.onnx +0 -0
  21. onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
  22. onnx/backend/test/data/node/test_attention_3d_expanded/model.onnx +0 -0
  23. onnx/backend/test/data/node/test_attention_3d_gqa/model.onnx +0 -0
  24. onnx/backend/test/data/node/test_attention_3d_gqa/test_data_set_0/output_0.pb +0 -0
  25. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/model.onnx +0 -0
  26. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
  27. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/model.onnx +0 -0
  28. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
  29. onnx/backend/test/data/node/test_attention_3d_gqa_causal/model.onnx +0 -0
  30. onnx/backend/test/data/node/test_attention_3d_gqa_causal/test_data_set_0/output_0.pb +0 -0
  31. onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/model.onnx +0 -0
  32. onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
  33. onnx/backend/test/data/node/test_attention_3d_gqa_expanded/model.onnx +0 -0
  34. onnx/backend/test/data/node/test_attention_3d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
  35. onnx/backend/test/data/node/test_attention_3d_gqa_scaled/model.onnx +0 -0
  36. onnx/backend/test/data/node/test_attention_3d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
  37. onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/model.onnx +0 -0
  38. onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
  39. onnx/backend/test/data/node/test_attention_3d_gqa_softcap/model.onnx +0 -0
  40. onnx/backend/test/data/node/test_attention_3d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
  41. onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/model.onnx +0 -0
  42. onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
  43. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/model.onnx +0 -0
  44. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
  45. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/model.onnx +0 -0
  46. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
  47. onnx/backend/test/data/node/test_attention_3d_scaled/model.onnx +0 -0
  48. onnx/backend/test/data/node/test_attention_3d_scaled_expanded/model.onnx +0 -0
  49. onnx/backend/test/data/node/test_attention_3d_softcap/model.onnx +0 -0
  50. onnx/backend/test/data/node/test_attention_3d_softcap_expanded/model.onnx +0 -0
  51. onnx/backend/test/data/node/test_attention_3d_transpose_verification/model.onnx +0 -0
  52. onnx/backend/test/data/node/test_attention_3d_transpose_verification_expanded/model.onnx +0 -0
  53. onnx/backend/test/data/node/test_attention_3d_with_past_and_present/model.onnx +0 -0
  54. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_expanded/model.onnx +0 -0
  55. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul/model.onnx +0 -0
  56. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
  57. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
  58. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
  59. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap/model.onnx +0 -0
  60. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap_expanded/model.onnx +0 -0
  61. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax/model.onnx +0 -0
  62. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax_expanded/model.onnx +0 -0
  63. onnx/backend/test/data/node/test_attention_4d/model.onnx +0 -0
  64. onnx/backend/test/data/node/test_attention_4d_attn_mask/model.onnx +0 -0
  65. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d/model.onnx +0 -0
  66. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal/model.onnx +0 -0
  67. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal_expanded/model.onnx +0 -0
  68. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_expanded/model.onnx +0 -0
  69. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d/model.onnx +0 -0
  70. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal/model.onnx +0 -0
  71. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal_expanded/model.onnx +0 -0
  72. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_expanded/model.onnx +0 -0
  73. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool/model.onnx +0 -0
  74. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d/model.onnx +0 -0
  75. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d_expanded/model.onnx +0 -0
  76. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_expanded/model.onnx +0 -0
  77. onnx/backend/test/data/node/test_attention_4d_attn_mask_expanded/model.onnx +0 -0
  78. onnx/backend/test/data/node/test_attention_4d_causal/model.onnx +0 -0
  79. onnx/backend/test/data/node/test_attention_4d_causal_expanded/model.onnx +0 -0
  80. onnx/backend/test/data/node/test_attention_4d_diff_heads_mask4d_padded_kv_expanded/model.onnx +0 -0
  81. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes/model.onnx +0 -0
  82. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask/model.onnx +0 -0
  83. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
  84. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal/model.onnx +0 -0
  85. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
  86. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_expanded/model.onnx +0 -0
  87. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled/model.onnx +0 -0
  88. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
  89. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap/model.onnx +0 -0
  90. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
  91. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present/model.onnx +0 -0
  92. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
  93. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d/model.onnx +0 -0
  94. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d_expanded/model.onnx +0 -0
  95. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d/model.onnx +0 -0
  96. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d_expanded/model.onnx +0 -0
  97. onnx/backend/test/data/node/test_attention_4d_expanded/model.onnx +0 -0
  98. onnx/backend/test/data/node/test_attention_4d_fp16/model.onnx +0 -0
  99. onnx/backend/test/data/node/test_attention_4d_fp16_expanded/model.onnx +0 -0
  100. onnx/backend/test/data/node/test_attention_4d_gqa/model.onnx +0 -0
  101. onnx/backend/test/data/node/test_attention_4d_gqa/test_data_set_0/output_0.pb +0 -0
  102. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/model.onnx +0 -0
  103. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
  104. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/model.onnx +0 -0
  105. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
  106. onnx/backend/test/data/node/test_attention_4d_gqa_causal/model.onnx +0 -0
  107. onnx/backend/test/data/node/test_attention_4d_gqa_causal/test_data_set_0/output_0.pb +0 -0
  108. onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/model.onnx +0 -0
  109. onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
  110. onnx/backend/test/data/node/test_attention_4d_gqa_expanded/model.onnx +0 -0
  111. onnx/backend/test/data/node/test_attention_4d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
  112. onnx/backend/test/data/node/test_attention_4d_gqa_scaled/model.onnx +0 -0
  113. onnx/backend/test/data/node/test_attention_4d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
  114. onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/model.onnx +0 -0
  115. onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
  116. onnx/backend/test/data/node/test_attention_4d_gqa_softcap/model.onnx +0 -0
  117. onnx/backend/test/data/node/test_attention_4d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
  118. onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/model.onnx +0 -0
  119. onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
  120. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/model.onnx +0 -0
  121. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
  122. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/model.onnx +0 -0
  123. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
  124. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/model.onnx +0 -0
  125. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/test_data_set_0/output_0.pb +0 -0
  126. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/model.onnx +0 -0
  127. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/test_data_set_0/output_0.pb +0 -0
  128. onnx/backend/test/data/node/test_attention_4d_scaled/model.onnx +0 -0
  129. onnx/backend/test/data/node/test_attention_4d_scaled_expanded/model.onnx +0 -0
  130. onnx/backend/test/data/node/test_attention_4d_softcap/model.onnx +0 -0
  131. onnx/backend/test/data/node/test_attention_4d_softcap_expanded/model.onnx +0 -0
  132. onnx/backend/test/data/node/test_attention_4d_with_past_and_present/model.onnx +0 -0
  133. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_expanded/model.onnx +0 -0
  134. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul/model.onnx +0 -0
  135. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
  136. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask/model.onnx +0 -0
  137. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/model.onnx +0 -0
  138. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_0.pb +0 -0
  139. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_3.pb +0 -0
  140. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/model.onnx +0 -0
  141. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
  142. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
  143. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_expanded/model.onnx +0 -0
  144. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask/model.onnx +0 -0
  145. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/model.onnx +0 -0
  146. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_0.pb +0 -0
  147. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_3.pb +0 -0
  148. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/model.onnx +0 -0
  149. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
  150. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
  151. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_expanded/model.onnx +0 -0
  152. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
  153. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
  154. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul/model.onnx +0 -0
  155. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias/model.onnx +0 -0
  156. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias_expanded/model.onnx +0 -0
  157. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_expanded/model.onnx +0 -0
  158. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap/model.onnx +0 -0
  159. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap_expanded/model.onnx +0 -0
  160. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax/model.onnx +0 -0
  161. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax_expanded/model.onnx +0 -0
  162. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/model.onnx +0 -0
  163. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_1.pb +1 -1
  164. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_2.pb +1 -1
  165. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/output_0.pb +0 -0
  166. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/model.onnx +0 -0
  167. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_1.pb +1 -1
  168. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_2.pb +1 -1
  169. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  170. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/model.onnx +0 -0
  171. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_1.pb +0 -0
  172. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_2.pb +0 -0
  173. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/output_0.pb +0 -0
  174. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/model.onnx +0 -0
  175. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
  176. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
  177. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  178. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/model.onnx +0 -0
  179. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_1.pb +0 -0
  180. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_2.pb +0 -0
  181. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/output_0.pb +0 -0
  182. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx +0 -0
  183. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
  184. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
  185. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  186. onnx/defs/nn/defs.cc +70 -228
  187. onnx/defs/nn/old.cc +31 -201
  188. onnx/defs/nn/utils.cc +222 -0
  189. onnx/defs/nn/utils.h +25 -0
  190. onnx/onnx_cpp2py_export.cp313t-win_amd64.pyd +0 -0
  191. onnx/reference/ops/op_attention.py +33 -14
  192. onnx/reference/ops/op_rotary_embedding.py +21 -19
  193. onnx/test/basic_test.py +84 -0
  194. onnx/test/reference_evaluator_test.py +23 -0
  195. onnx/test/test_backend_reference.py +2 -1
  196. onnx/version.py +2 -2
  197. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/METADATA +2 -2
  198. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/RECORD +202 -200
  199. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/WHEEL +1 -1
  200. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/entry_points.txt +0 -0
  201. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/licenses/LICENSE +0 -0
  202. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/top_level.txt +0 -0
onnx/defs/nn/old.cc CHANGED
@@ -5,6 +5,7 @@
 #include <cmath>
 
 #include "onnx/defs/function.h"
+#include "onnx/defs/nn/utils.h"
 #include "onnx/defs/schema.h"
 
 namespace ONNX_NAMESPACE {
@@ -4373,154 +4374,7 @@ ONNX_OPERATOR_SET_SCHEMA(
             "U",
             OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
             "Constrain output 'mask' types to boolean tensors and input types.")
-        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
-          propagateElemTypeFromInputToOutput(ctx, 0, 0);
-
-          int64_t kv_sequence_length = -1;
-          ONNX_NAMESPACE::TensorShapeProto output_shape;
-          ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
-          if (hasInputShape(ctx, 0)) {
-            auto& query_shape = getInputShape(ctx, 0);
-            auto& query_dims = query_shape.dim();
-            if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
-              fail_shape_inference("Inputs 0 (query) shall be 3 or 4 dimensions");
-            }
-
-            if (query_dims.size() == 3) {
-              auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
-              if (q_num_heads_attr == nullptr) {
-                fail_type_inference("3D inputs expected to have q_num_heads attribute.");
-              }
-              auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
-              if (kv_num_heads_attr == nullptr) {
-                fail_type_inference("3D inputs expected to have q_num_heads attribute.");
-              }
-            }
-
-            *output_shape.add_dim() = query_dims[0]; // batch_size
-            *output_shape.add_dim() = query_dims[1]; // num_heads for 4D, sequence_length for 3D
-
-            *qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
-
-            if (hasInputShape(ctx, 1)) {
-              auto& key_shape = getInputShape(ctx, 1);
-              auto& key_dims = key_shape.dim();
-              if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
-                fail_shape_inference("Inputs 1 (key) shall be 3 or 4 dimensions");
-              }
-            }
-
-            if (hasInputShape(ctx, 2)) {
-              auto& value_shape = getInputShape(ctx, 2);
-              auto& value_dims = value_shape.dim();
-              if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
-                fail_shape_inference("Inputs 2 (value) shall be 3 or 4 dimensions");
-              }
-
-              // Update Output Shape for 4D inputs
-              // Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
-              // Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
-              // Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
-              // Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
-              if (value_dims.size() == 4 && query_dims.size() == 4) {
-                kv_sequence_length = value_dims[2].dim_value();
-                *output_shape.add_dim() = query_dims[2]; // sequence_length
-                *output_shape.add_dim() = value_dims[3]; // head_size
-                updateOutputShape(ctx, 0, output_shape);
-                // Update qk_matmul_shape
-                *qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
-                *qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
-                qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
-              }
-
-              // Update Output Shape for 3D inputs
-              // Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
-              //   q_hidden_size = q_num_heads * head_size
-              // Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
-              //   k_hidden_size = kv_num_heads * head_size
-              // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
-              //   v_hidden_size = kv_num_heads * v_head_size
-              // Output 0 has shape (batch_size, q_sequence_length, hidden_size),
-              //   hidden_size = q_num_heads * v_head_size
-              if (value_dims.size() == 3 && query_dims.size() == 3) {
-                kv_sequence_length = value_dims[1].dim_value();
-                auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
-                if (q_num_heads_attr == nullptr) {
-                  fail_type_inference("3D inputs expected to have q_num_heads attribute.");
-                }
-                auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
-                if (kv_num_heads_attr == nullptr) {
-                  fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
-                }
-                int64_t q_num_heads = q_num_heads_attr->i();
-                int64_t kv_num_heads = kv_num_heads_attr->i();
-                // Calculate v_head_size
-                int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
-                output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
-                updateOutputShape(ctx, 0, output_shape);
-                // Update qk_matmul_shape
-                qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
-                *qk_matmul_shape.add_dim() = query_dims[1];
-                qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
-              }
-            }
-          }
-
-          if (ctx.hasOutput(3)) { // has qk_matmul_output
-            propagateElemTypeFromInputToOutput(ctx, 0, 3);
-            updateOutputShape(ctx, 3, qk_matmul_shape);
-          }
-
-          if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
-            if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key
-              // copy the type from query to present key and value
-              propagateElemTypeFromInputToOutput(ctx, 4, 1);
-              propagateElemTypeFromInputToOutput(ctx, 5, 2);
-
-              if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
-                auto& past_key_shape = getInputShape(ctx, 4);
-                auto& past_key_dims = past_key_shape.dim();
-                auto& past_value_shape = getInputShape(ctx, 5);
-                auto& past_value_dims = past_value_shape.dim();
-
-                // past key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
-                if (past_key_dims.size() != 4) {
-                  fail_shape_inference("The past_key input shall be 4 dimensions");
-                }
-                // past value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
-                if (past_value_dims.size() != 4) {
-                  fail_shape_inference("The past_value input shall be 4 dimensions");
-                }
-
-                if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
-                  int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
-
-                  ONNX_NAMESPACE::TensorShapeProto present_key_shape;
-                  for (auto& dim : past_key_dims) {
-                    *present_key_shape.add_dim() = dim;
-                  }
-
-                  ONNX_NAMESPACE::TensorShapeProto present_value_shape;
-                  for (auto& dim : past_value_dims) {
-                    *present_value_shape.add_dim() = dim;
-                  }
-
-                  if (ctx.hasOutput(3)) { // has qk_matmul_output with bias
-                    qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
-                    updateOutputShape(ctx, 3, qk_matmul_shape);
-                  }
-
-                  // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
-                  present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
-                  present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
-
-                  updateOutputShape(ctx, 1, present_key_shape);
-                  updateOutputShape(ctx, 2, present_value_shape);
-                }
-              }
-            }
-          }
-        })
+        .TypeAndShapeInferenceFunction(defs::nn::utils::AttentionPropagateElemTypeFromInputToOutput)
         .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx,
                                                    const OpSchema& schema,
                                                    FunctionProto& functionProto) {
@@ -4544,12 +4398,6 @@ ONNX_OPERATOR_SET_SCHEMA(
               (softmax_precision != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))
             return false; // Error
 
-          auto mkbooltensor = [](bool val) -> ONNX_NAMESPACE::TensorProto {
-            auto tp = ONNX_NAMESPACE::ToTensor(std::vector<bool>{val});
-            tp.add_dims(1);
-            return tp;
-          };
-
           // If shape is 3D, q_num_heads and kv_num_heads is provided,
           // for 4D cases, set num_heads to zero for reshape purposes
           auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
@@ -4561,15 +4409,16 @@ ONNX_OPERATOR_SET_SCHEMA(
           bool is_3d_input = (q_num_heads > 0 && kv_num_heads > 0);
 
           FunctionBuilder builder(functionProto);
+          builder
+              .Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
+              .Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
+              .Add("KVSeqLen = Shape <start = -2, end = -1> (K)"); // kv_sequence_length
           if (is_3d_input) {
             // For 3D inputs: First reshape to [batch_size, seq_length, num_heads, head_size]
             // then transpose to [batch_size, num_heads, seq_length, head_size]
             builder
-                .Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
                 .Const1D("QNumHeadsAttr", q_num_heads) // q_num_heads from attrs
                 .Const1D("KVNumHeadsAttr", kv_num_heads) // kv_num_heads from attrs
-                .Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
-                .Add("KVSeqLen = Shape <start = -2, end = -1> (K)") // kv_sequence_length
                 .Const1D("NegOne", static_cast<int64_t>(-1)); // head_size, inferred from other dimensions
 
             builder.Add("QIntermediateShape = Concat <axis = 0> (BatchSize, QSeqLen, QNumHeadsAttr, NegOne)")
@@ -4596,6 +4445,7 @@ ONNX_OPERATOR_SET_SCHEMA(
           builder
              .Add("QKHeadSize = Shape <start = 3, end = 4> (QReshaped)") // head_size for Q and K
              .Add("QKHeadSizeF = Cast (QKHeadSize)", "to", float_type)
+             .Add("VHeadSize = Shape <start = 3, end = 4> (VReshaped)") // head_size for V
              .Add("SqrtHeadSize = Sqrt(QKHeadSizeF)")
              .Const1D("One1D", static_cast<int64_t>(1))
              .Const1D("One1DF", static_cast<float>(1))
@@ -4610,8 +4460,10 @@ ONNX_OPERATOR_SET_SCHEMA(
 
           if (ctx.hasInput(4)) {
             builder.Add("PresentKey = Concat <axis = 2> (past_key, KReshaped)");
+            builder.Add("PastKVSeqLen = Shape <start = -2, end = -1> (past_key)");
           } else {
             builder.Add("PresentKey = Identity (KReshaped)");
+            builder.Const1D("PastKVSeqLen", static_cast<int64_t>(0));
           }
           if (ctx.hasOutput(1)) {
             builder.Add("present_key = Identity (PresentKey)");
@@ -4626,46 +4478,9 @@ ONNX_OPERATOR_SET_SCHEMA(
             builder.Add("present_value = Identity (PresentValue)");
           }
 
-          // If attn_mask is provided
-          float neg_inf = -std::numeric_limits<float>::infinity();
-          if (ctx.hasInput(3)) {
-            auto* up = ctx.getInputType(3);
-            if ((up == nullptr) || (!up->has_tensor_type()))
-              return false;
-            int64_t U = up->tensor_type().elem_type();
-            builder.Const1D("FloatInf", neg_inf);
-            builder.Const1D("ScalarZero", 0.f);
-            builder.Add(
-                U == ONNX_NAMESPACE::TensorProto_DataType_BOOL ? "AttnBias = Where(attn_mask, ScalarZero, FloatInf)"
-                                                               : "AttnBias = Identity(attn_mask)");
-          } else {
-            // If is_causal set to true, the attention masking is a lower triangular matrix when the mask
-            // is a square matrix. The attention masking has the form of the upper left causal bias due to
-            // the alignment when the mask is a non-square matrix.
-            // An error is thrown if both attn_mask and is_causal are set.
-            auto* is_causal_attr = ctx.getAttribute("is_causal");
-            int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
-            if (!is_3d_input) {
-              builder.Add("QSeqLen = Shape <start = -2, end = -1> (Q)"); // q_sequence_length
-            }
-
-            if (is_causal == 1) {
-              builder.Const1D("FloatInf", neg_inf);
-              builder.Const1D("ScalarZero", 0.f);
-              builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
-                  .Add("AttnBiasShape = Concat <axis = -1> (QSeqLen, NewKVSeqLen)")
-                  .Add("AttnBiasZeros_ = ConstantOfShape(AttnBiasShape)")
-                  .Add("AttnBiasZeros = CastLike(AttnBiasZeros_, Q)")
-                  .Add("BoolMask = ConstantOfShape(AttnBiasShape)", "value", mkbooltensor(1))
-                  .Add("BoolMaskTri = Trilu <upper = 0> (BoolMask, Zero1D)")
-                  .Add("AttnBias = Where(BoolMaskTri, ScalarZero, FloatInf)");
-            } else {
-              builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
-                  .Add("AttnBiasShape = Concat <axis = -1> (QSeqLen, NewKVSeqLen)")
-                  .Add("AttnBias = ConstantOfShape(AttnBiasShape)");
-            }
-          }
-          builder.Add("AttnBiasT = Cast (AttnBias)", "to", T1);
+          if (!defs::nn::utils::AttentionAppendFunctionCausalMask(ctx, builder, false))
+            return false;
+          builder.Add("AttnBiasT = Cast (AttnBiasCausalOrNot)", "to", T1);
 
           // Group Query Attention is applied if the following are satisfied
           // 1) q_num_heads != kv_num_heads
@@ -4678,10 +4493,25 @@ ONNX_OPERATOR_SET_SCHEMA(
               .Add("RemainderNumHeads = Mod(QNumHeads, KVNumHeads)")
               .Add("GQACond2 = Equal(RemainderNumHeads, Zero1D)")
               .Add("GQACond = And(GQACond1, GQACond2)")
-              .Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)")
-              .Add("InterleaveShape = Concat <axis = 0> (One1D, InterleaveDim, One1D, One1D)")
-              .Add("KAttentionInput = Tile(PresentKey, InterleaveShape)")
-              .Add("VAttentionInput = Tile(PresentValue, InterleaveShape)");
+              .Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)");
+
+          // repeat kv (repeat_interleave)
+          builder.Const1D("Two1D", static_cast<int64_t>(2))
+              .Add("KUnsqueezed = Unsqueeze(PresentKey, Two1D)") // [B, Hk, 1, T, Dk]
+              .Add("VUnsqueezed = Unsqueeze(PresentValue, Two1D)"); // [B, Hk, 1, T, Dv]
+
+          // Build expand shape: [B, Hk, repeats, T, Dk]
+          builder
+              .Add("KExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, QKHeadSize)")
+              .Add("KExpanded = Expand(KUnsqueezed, KExpandShape)");
+          builder.Add("VExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, VHeadSize)")
+              .Add("VExpanded = Expand(VUnsqueezed, VExpandShape)");
+
+          // Reshape to [B, Hq, T, Dk] where Hq = Hk * repeats
+          builder.Add("KAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, QKHeadSize)")
+              .Add("VAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, VHeadSize)")
+              .Add("KAttentionInput = Reshape(KExpanded, KAttentionShape)")
+              .Add("VAttentionInput = Reshape(VExpanded, VAttentionShape)");
 
           // The following pattern is applied
           //    Q    K    V
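
The GQA expansion rewrite above is the functional core of this hunk: Tile repeats the whole block of KV heads (head order [h0, h1, h0, h1, ...]), while grouped-query attention needs each KV head repeated in place ([h0, h0, h1, h1, ...]), i.e. repeat_interleave semantics. A minimal numpy sketch of the difference (illustrative only, not part of the diff; all names and sizes here are made up):

import numpy as np

B, Hk, T, D = 1, 2, 3, 4  # batch, kv heads, kv seq length, head size
repeats = 2               # q_num_heads // kv_num_heads
K = np.arange(B * Hk * T * D, dtype=np.float32).reshape(B, Hk, T, D)

# Old expansion (Tile): repeats the whole head block -> [h0, h1, h0, h1]
k_tiled = np.tile(K, (1, repeats, 1, 1))

# New expansion (Unsqueeze -> Expand -> Reshape): repeats each head in
# place -> [h0, h0, h1, h1], matching repeat_interleave
k_expanded = np.broadcast_to(
    K[:, :, None, :, :], (B, Hk, repeats, T, D)
).reshape(B, Hk * repeats, T, D)

assert np.array_equal(k_expanded, np.repeat(K, repeats, axis=1))
assert not np.array_equal(k_tiled, k_expanded)  # different head ordering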
onnx/defs/nn/utils.cc ADDED
@@ -0,0 +1,222 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+#include "onnx/defs/nn/utils.h"
+
+#include <algorithm>
+
+namespace ONNX_NAMESPACE {
+namespace defs {
+namespace nn {
+namespace utils {
+
+void AttentionPropagateElemTypeFromInputToOutput(InferenceContext& ctx) {
+  propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+  int64_t kv_sequence_length = -1;
+  ONNX_NAMESPACE::TensorShapeProto output_shape;
+  ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
+  if (hasInputShape(ctx, 0)) {
+    auto& query_shape = getInputShape(ctx, 0);
+    auto& query_dims = query_shape.dim();
+    if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
+      fail_shape_inference("Inputs 0 (query) shall be 3 or 4 dimensions");
+    }
+
+    if (query_dims.size() == 3) {
+      auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
+      if (q_num_heads_attr == nullptr) {
+        fail_type_inference("3D inputs expected to have q_num_heads attribute.");
+      }
+      auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
+      if (kv_num_heads_attr == nullptr) {
+        fail_type_inference("3D inputs expected to have q_num_heads attribute.");
+      }
+    }
+
+    *output_shape.add_dim() = query_dims[0]; // batch_size
+    *output_shape.add_dim() = query_dims[1]; // num_heads for 4D, sequence_length for 3D
+
+    *qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
+
+    if (hasInputShape(ctx, 1)) {
+      auto& key_shape = getInputShape(ctx, 1);
+      auto& key_dims = key_shape.dim();
+      if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
+        fail_shape_inference("Inputs 1 (key) shall be 3 or 4 dimensions");
+      }
+    }
+
+    if (hasInputShape(ctx, 2)) {
+      auto& value_shape = getInputShape(ctx, 2);
+      auto& value_dims = value_shape.dim();
+      if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
+        fail_shape_inference("Inputs 2 (value) shall be 3 or 4 dimensions");
+      }
+
+      // Update Output Shape for 4D inputs
+      // Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
+      // Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
+      // Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
+      // Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
+      if (value_dims.size() == 4 && query_dims.size() == 4) {
+        kv_sequence_length = value_dims[2].dim_value();
+        *output_shape.add_dim() = query_dims[2]; // sequence_length
+        *output_shape.add_dim() = value_dims[3]; // head_size
+        updateOutputShape(ctx, 0, output_shape);
+        // Update qk_matmul_shape
+        *qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
+        *qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
+        qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
+      }
+
+      // Update Output Shape for 3D inputs
+      // Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
+      //   q_hidden_size = q_num_heads * head_size
+      // Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
+      //   k_hidden_size = kv_num_heads * head_size
+      // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
+      //   v_hidden_size = kv_num_heads * v_head_size
+      // Output 0 has shape (batch_size, q_sequence_length, hidden_size),
+      //   hidden_size = q_num_heads * v_head_size
+      if (value_dims.size() == 3 && query_dims.size() == 3) {
+        kv_sequence_length = value_dims[1].dim_value();
+        auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
+        if (q_num_heads_attr == nullptr) {
+          fail_type_inference("3D inputs expected to have q_num_heads attribute.");
+        }
+        auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
+        if (kv_num_heads_attr == nullptr) {
+          fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
+        }
+        int64_t q_num_heads = q_num_heads_attr->i();
+        int64_t kv_num_heads = kv_num_heads_attr->i();
+        // Calculate v_head_size
+        int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
+        output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
+        updateOutputShape(ctx, 0, output_shape);
+        // Update qk_matmul_shape
+        qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
+        *qk_matmul_shape.add_dim() = query_dims[1];
+        qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
+      }
+    }
+  }
+
+  if (ctx.hasOutput(3)) { // has qk_matmul_output
+    propagateElemTypeFromInputToOutput(ctx, 0, 3);
+    updateOutputShape(ctx, 3, qk_matmul_shape);
+  }
+
+  if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
+    if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key
+      // copy the type from query to present key and value
+      propagateElemTypeFromInputToOutput(ctx, 4, 1);
+      propagateElemTypeFromInputToOutput(ctx, 5, 2);
+
+      if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
+        auto& past_key_shape = getInputShape(ctx, 4);
+        auto& past_key_dims = past_key_shape.dim();
+        auto& past_value_shape = getInputShape(ctx, 5);
+        auto& past_value_dims = past_value_shape.dim();
+
+        // past key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
+        if (past_key_dims.size() != 4) {
+          fail_shape_inference("The past_key input shall be 4 dimensions");
+        }
+        // past value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
+        if (past_value_dims.size() != 4) {
+          fail_shape_inference("The past_value input shall be 4 dimensions");
+        }
+
+        if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
+          int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
+
+          ONNX_NAMESPACE::TensorShapeProto present_key_shape;
+          for (auto& dim : past_key_dims) {
+            *present_key_shape.add_dim() = dim;
+          }
+
+          ONNX_NAMESPACE::TensorShapeProto present_value_shape;
+          for (auto& dim : past_value_dims) {
+            *present_value_shape.add_dim() = dim;
+          }
+
+          if (ctx.hasOutput(3)) { // has qk_matmul_output with bias
+            qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
+            updateOutputShape(ctx, 3, qk_matmul_shape);
+          }
+
+          // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
+          present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
+          present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
+
+          updateOutputShape(ctx, 1, present_key_shape);
+          updateOutputShape(ctx, 2, present_value_shape);
+        }
+      }
+    }
+  }
+}
+
+bool AttentionAppendFunctionCausalMask(const FunctionBodyBuildContext& ctx, FunctionBuilder& builder, bool padding) {
+  builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
+      .Add("AttnBiasShape = Concat <axis = 0> (QSeqLen, NewKVSeqLen)");
+  float neg_inf = -std::numeric_limits<float>::infinity();
+  builder.Const1D("FloatNegInf", neg_inf);
+  builder.Const1D("ScalarZero", 0.f);
+
+  // If attn_mask is provided
+  if (ctx.hasInput(3)) {
+    auto* up = ctx.getInputType(3);
+    if ((up == nullptr) || (!up->has_tensor_type()))
+      return false;
+    int64_t U = up->tensor_type().elem_type();
+    builder.Add(
+        U == ONNX_NAMESPACE::TensorProto_DataType_BOOL ? "AttnBiasShort = Where(attn_mask, ScalarZero, FloatNegInf)"
+                                                       : "AttnBiasShort = Identity(attn_mask)");
+    // If attn_mask has a shorter kv sequence length, we pad it to NewKVSeqLen with FloatNegInf
+    if (padding) {
+      builder.Add("MaskKVSeqLen = Shape <start = -1> (attn_mask)")
+          .Add("PaddingKVSeqLen = Sub(NewKVSeqLen, MaskKVSeqLen)")
+          .Add("Pads = Concat <axis = 0> (Zero1D, PaddingKVSeqLen)")
+          .Add("FloatNegInfCast = CastLike(FloatNegInf, AttnBiasShort)")
+          .Add("AttnBias = Pad(AttnBiasShort, Pads, FloatNegInfCast, NegOne1D)");
+    } else {
+      builder.Add("AttnBias = Identity(AttnBiasShort)");
+    }
+  } else {
+    builder.Add("AttnBias = ConstantOfShape(AttnBiasShape)");
+  }
+
+  // If is_causal set to true, the attention masking is a lower triangular matrix when the mask
+  // is a square matrix. The attention masking has the form of the upper left causal bias due to
+  // the alignment when the mask is a non-square matrix.
+  // An error is thrown if both attn_mask and is_causal are set.
+  auto* is_causal_attr = ctx.getAttribute("is_causal");
+  int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
+  if (is_causal == 1) {
+    builder.Const1D("Zero", static_cast<int64_t>(0))
+        .Const1D("One", static_cast<int64_t>(1))
+        .Add("ZeroNoDim = Squeeze(Zero, Zero)")
+        .Add("OneNoDim = Squeeze(One, Zero)")
+        .Add("SequenceLength = Gather(AttnBiasShape, ZeroNoDim)")
+        .Add("TotalSequenceLength = Gather(AttnBiasShape, OneNoDim)")
+        .Add("RangeRow = Range(ZeroNoDim, SequenceLength, OneNoDim)")
+        .Add("RangeRow2D = Unsqueeze(RangeRow, One)")
+        .Add("RangeCol = Range(ZeroNoDim, TotalSequenceLength, OneNoDim)")
+        .Add("RangeCol2D = Unsqueeze(RangeCol, Zero)")
+        .Add("RangeRow2DPast = Add(RangeRow2D, PastKVSeqLen)")
+        .Add("BoolMaskTri = Less(RangeRow2DPast, RangeCol2D)")
+        .Add("MaskTri = Where(BoolMaskTri, FloatNegInf, ScalarZero)")
+        .Add("AttnBiasCausalOrNot = Add(AttnBias, MaskTri)");
+  } else {
+    builder.Add("AttnBiasCausalOrNot = Identity(AttnBias)");
+  }
+  return true;
+}
+
+} // namespace utils
+} // namespace nn
+} // namespace defs
+} // namespace ONNX_NAMESPACE
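
The Range/Less construction in AttentionAppendFunctionCausalMask replaces the old Trilu-based mask and is what lets the causal frontier account for cached positions: each query row i is offset by PastKVSeqLen before being compared against key column j. A numpy sketch of the bias it produces (illustrative only; the sizes are made up):

import numpy as np

q_seq_len, past_len, total_len = 3, 2, 5          # illustrative sizes
rows = np.arange(q_seq_len)[:, None] + past_len   # RangeRow2DPast
cols = np.arange(total_len)[None, :]              # RangeCol2D
bool_mask_tri = rows < cols                       # BoolMaskTri = Less(...)
mask_tri = np.where(bool_mask_tri, -np.inf, 0.0)  # MaskTri = Where(...)
print(mask_tri)
# [[  0.   0.   0. -inf -inf]
#  [  0.   0.   0.   0. -inf]
#  [  0.   0.   0.   0.   0.]]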
onnx/defs/nn/utils.h ADDED
@@ -0,0 +1,25 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+#include "onnx/common/assertions.h"
+#include "onnx/defs/function.h"
+#include "onnx/defs/schema.h"
+
+namespace ONNX_NAMESPACE {
+namespace defs {
+namespace nn {
+namespace utils {
+
+/** Implements shape and type propagation for Attention (23-). */
+void AttentionPropagateElemTypeFromInputToOutput(InferenceContext& ctx);
+
+/** Implements CausalMask for Attention. */
+bool AttentionAppendFunctionCausalMask(const FunctionBodyBuildContext& ctx, FunctionBuilder& builder, bool padding);
+
+} // namespace utils
+} // namespace nn
+} // namespace defs
+} // namespace ONNX_NAMESPACE
onnx/reference/ops/op_attention.py CHANGED
@@ -24,6 +24,25 @@ def _softcap(X, softcap):
     return X
 
 
+def _apply_causal(mask, past_sequence_length):
+    """Applies a causal mask on the input `mask`:
+    ``mask[i, j] = -inf if past_sequence_length + i > j else 0``.
+    Because a softmax is applied on the mask, -inf becomes 0 and 0 becomes 1.
+    The modification is done inplace.
+    """
+    q_sequence_length, total_sequence_length = mask.shape[-2:]
+    triu = np.triu(
+        np.ones(
+            (q_sequence_length, total_sequence_length - past_sequence_length),
+            dtype=mask.dtype,
+        ),
+        k=1,
+    )
+    triu[triu == 1] = -np.inf
+    mask[..., :, past_sequence_length:] += triu
+    return mask
+
+
 def _compute_attention(
     Q: np.ndarray,
     K: np.ndarray,
@@ -114,18 +133,18 @@ def _compute_attention(
     # bias due to the alignment when the mask is a non-square matrix.
     if is_causal:
         if attn_mask is None:
-            temp_mask = np.ones((q_sequence_length, kv_sequence_length), dtype=bool)
-            temp_mask = np.tril(temp_mask, k=0)
-            temp_mask = np.logical_not(temp_mask)
-            attn_bias_ma = np.ma.array(attn_bias, mask=temp_mask)
-            attn_bias = attn_bias_ma.filled(fill_value=float("-inf"))
+            temp_mask = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)
+            attn_bias = _apply_causal(
+                temp_mask,
+                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
+            )
         else:
             if attn_mask.dtype == np.bool_:
                 attn_mask = (1 - attn_mask).astype(Q.dtype) * (-np.inf)
-            temp_mask = np.ones((q_sequence_length, kv_sequence_length), dtype=Q.dtype)
-            temp_mask = 1 - np.tril(temp_mask, k=0)
-            temp_mask[temp_mask == 1] = -np.inf
-            attn_bias = attn_mask + temp_mask
+            attn_bias = _apply_causal(
+                attn_mask.copy(),
+                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
+            )
     elif attn_mask is not None:
         if attn_mask.dtype == np.bool_:
             attn_mask = (1 - attn_mask).astype(Q.dtype)
@@ -158,10 +177,10 @@
         and (q_num_heads % k_num_heads == 0)
         and (k_num_heads == v_num_heads)
     ):
-        seq_reps = int(q_num_heads / k_num_heads)
-        reps = [1, seq_reps, 1, 1]
-        K = np.tile(K, reps)
-        V = np.tile(V, reps)
+        seq_reps = q_num_heads // k_num_heads
+        # Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
+        K = np.repeat(K, repeats=seq_reps, axis=1)
+        V = np.repeat(V, repeats=seq_reps, axis=1)
 
     # The following pattern is applied
     #    Q    K    V
@@ -185,7 +204,7 @@
     qk_matmul_output = np.matmul(Q * scale, k_transpose * scale)
     qk_with_bias = qk_matmul_output + attn_bias
     if qk_matmul_output_mode == 1:
-        qk_matmul_output = qk_matmul_output + attn_bias
+        qk_matmul_output = qk_with_bias.copy()
 
     # Apply softcap
     if softcap is not None:
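
A short usage sketch of the new _apply_causal helper (assuming the definition from the hunk above is in scope; the sizes are made up): with three cached KV positions, query row i only masks key positions beyond past_sequence_length + i, so the cache stays fully visible.

import numpy as np

# assumes _apply_causal from the hunk above is in scope
mask = np.zeros((2, 5), dtype=np.float32)    # (q_sequence_length, total_sequence_length)
_apply_causal(mask, past_sequence_length=3)  # three cached KV positions
print(mask)
# [[  0.   0.   0.   0. -inf]
#  [  0.   0.   0.   0.   0.]]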