onnx 1.19.0__cp310-cp310-win_amd64.whl → 1.19.1rc1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx might be problematic. Click here for more details.

Files changed (202) hide show
  1. onnx/__init__.py +98 -0
  2. onnx/backend/test/case/node/__init__.py +20 -3
  3. onnx/backend/test/case/node/attention.py +62 -0
  4. onnx/backend/test/case/node/rotaryembedding.py +6 -6
  5. onnx/backend/test/data/node/test_attention_3d/model.onnx +0 -0
  6. onnx/backend/test/data/node/test_attention_3d_attn_mask/model.onnx +0 -0
  7. onnx/backend/test/data/node/test_attention_3d_attn_mask_expanded/model.onnx +0 -0
  8. onnx/backend/test/data/node/test_attention_3d_causal/model.onnx +0 -0
  9. onnx/backend/test/data/node/test_attention_3d_causal_expanded/model.onnx +0 -0
  10. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes/model.onnx +0 -0
  11. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask/model.onnx +0 -0
  12. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
  13. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal/model.onnx +0 -0
  14. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
  15. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_expanded/model.onnx +0 -0
  16. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled/model.onnx +0 -0
  17. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
  18. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap/model.onnx +0 -0
  19. onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
  20. onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present/model.onnx +0 -0
  21. onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
  22. onnx/backend/test/data/node/test_attention_3d_expanded/model.onnx +0 -0
  23. onnx/backend/test/data/node/test_attention_3d_gqa/model.onnx +0 -0
  24. onnx/backend/test/data/node/test_attention_3d_gqa/test_data_set_0/output_0.pb +0 -0
  25. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/model.onnx +0 -0
  26. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
  27. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/model.onnx +0 -0
  28. onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
  29. onnx/backend/test/data/node/test_attention_3d_gqa_causal/model.onnx +0 -0
  30. onnx/backend/test/data/node/test_attention_3d_gqa_causal/test_data_set_0/output_0.pb +0 -0
  31. onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/model.onnx +0 -0
  32. onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
  33. onnx/backend/test/data/node/test_attention_3d_gqa_expanded/model.onnx +0 -0
  34. onnx/backend/test/data/node/test_attention_3d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
  35. onnx/backend/test/data/node/test_attention_3d_gqa_scaled/model.onnx +0 -0
  36. onnx/backend/test/data/node/test_attention_3d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
  37. onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/model.onnx +0 -0
  38. onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
  39. onnx/backend/test/data/node/test_attention_3d_gqa_softcap/model.onnx +0 -0
  40. onnx/backend/test/data/node/test_attention_3d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
  41. onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/model.onnx +0 -0
  42. onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
  43. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/model.onnx +0 -0
  44. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
  45. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/model.onnx +0 -0
  46. onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
  47. onnx/backend/test/data/node/test_attention_3d_scaled/model.onnx +0 -0
  48. onnx/backend/test/data/node/test_attention_3d_scaled_expanded/model.onnx +0 -0
  49. onnx/backend/test/data/node/test_attention_3d_softcap/model.onnx +0 -0
  50. onnx/backend/test/data/node/test_attention_3d_softcap_expanded/model.onnx +0 -0
  51. onnx/backend/test/data/node/test_attention_3d_transpose_verification/model.onnx +0 -0
  52. onnx/backend/test/data/node/test_attention_3d_transpose_verification_expanded/model.onnx +0 -0
  53. onnx/backend/test/data/node/test_attention_3d_with_past_and_present/model.onnx +0 -0
  54. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_expanded/model.onnx +0 -0
  55. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul/model.onnx +0 -0
  56. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
  57. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
  58. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
  59. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap/model.onnx +0 -0
  60. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap_expanded/model.onnx +0 -0
  61. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax/model.onnx +0 -0
  62. onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax_expanded/model.onnx +0 -0
  63. onnx/backend/test/data/node/test_attention_4d/model.onnx +0 -0
  64. onnx/backend/test/data/node/test_attention_4d_attn_mask/model.onnx +0 -0
  65. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d/model.onnx +0 -0
  66. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal/model.onnx +0 -0
  67. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal_expanded/model.onnx +0 -0
  68. onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_expanded/model.onnx +0 -0
  69. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d/model.onnx +0 -0
  70. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal/model.onnx +0 -0
  71. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal_expanded/model.onnx +0 -0
  72. onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_expanded/model.onnx +0 -0
  73. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool/model.onnx +0 -0
  74. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d/model.onnx +0 -0
  75. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d_expanded/model.onnx +0 -0
  76. onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_expanded/model.onnx +0 -0
  77. onnx/backend/test/data/node/test_attention_4d_attn_mask_expanded/model.onnx +0 -0
  78. onnx/backend/test/data/node/test_attention_4d_causal/model.onnx +0 -0
  79. onnx/backend/test/data/node/test_attention_4d_causal_expanded/model.onnx +0 -0
  80. onnx/backend/test/data/node/test_attention_4d_diff_heads_mask4d_padded_kv_expanded/model.onnx +0 -0
  81. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes/model.onnx +0 -0
  82. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask/model.onnx +0 -0
  83. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
  84. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal/model.onnx +0 -0
  85. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
  86. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_expanded/model.onnx +0 -0
  87. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled/model.onnx +0 -0
  88. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
  89. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap/model.onnx +0 -0
  90. onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
  91. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present/model.onnx +0 -0
  92. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
  93. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d/model.onnx +0 -0
  94. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d_expanded/model.onnx +0 -0
  95. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d/model.onnx +0 -0
  96. onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d_expanded/model.onnx +0 -0
  97. onnx/backend/test/data/node/test_attention_4d_expanded/model.onnx +0 -0
  98. onnx/backend/test/data/node/test_attention_4d_fp16/model.onnx +0 -0
  99. onnx/backend/test/data/node/test_attention_4d_fp16_expanded/model.onnx +0 -0
  100. onnx/backend/test/data/node/test_attention_4d_gqa/model.onnx +0 -0
  101. onnx/backend/test/data/node/test_attention_4d_gqa/test_data_set_0/output_0.pb +0 -0
  102. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/model.onnx +0 -0
  103. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
  104. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/model.onnx +0 -0
  105. onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
  106. onnx/backend/test/data/node/test_attention_4d_gqa_causal/model.onnx +0 -0
  107. onnx/backend/test/data/node/test_attention_4d_gqa_causal/test_data_set_0/output_0.pb +0 -0
  108. onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/model.onnx +0 -0
  109. onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
  110. onnx/backend/test/data/node/test_attention_4d_gqa_expanded/model.onnx +0 -0
  111. onnx/backend/test/data/node/test_attention_4d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
  112. onnx/backend/test/data/node/test_attention_4d_gqa_scaled/model.onnx +0 -0
  113. onnx/backend/test/data/node/test_attention_4d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
  114. onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/model.onnx +0 -0
  115. onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
  116. onnx/backend/test/data/node/test_attention_4d_gqa_softcap/model.onnx +0 -0
  117. onnx/backend/test/data/node/test_attention_4d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
  118. onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/model.onnx +0 -0
  119. onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
  120. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/model.onnx +0 -0
  121. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
  122. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/model.onnx +0 -0
  123. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
  124. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/model.onnx +0 -0
  125. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/test_data_set_0/output_0.pb +0 -0
  126. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/model.onnx +0 -0
  127. onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/test_data_set_0/output_0.pb +0 -0
  128. onnx/backend/test/data/node/test_attention_4d_scaled/model.onnx +0 -0
  129. onnx/backend/test/data/node/test_attention_4d_scaled_expanded/model.onnx +0 -0
  130. onnx/backend/test/data/node/test_attention_4d_softcap/model.onnx +0 -0
  131. onnx/backend/test/data/node/test_attention_4d_softcap_expanded/model.onnx +0 -0
  132. onnx/backend/test/data/node/test_attention_4d_with_past_and_present/model.onnx +0 -0
  133. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_expanded/model.onnx +0 -0
  134. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul/model.onnx +0 -0
  135. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
  136. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask/model.onnx +0 -0
  137. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/model.onnx +0 -0
  138. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_0.pb +0 -0
  139. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_3.pb +0 -0
  140. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/model.onnx +0 -0
  141. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
  142. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
  143. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_expanded/model.onnx +0 -0
  144. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask/model.onnx +0 -0
  145. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/model.onnx +0 -0
  146. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_0.pb +0 -0
  147. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_3.pb +0 -0
  148. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/model.onnx +0 -0
  149. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
  150. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
  151. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_expanded/model.onnx +0 -0
  152. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
  153. onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
  154. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul/model.onnx +0 -0
  155. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias/model.onnx +0 -0
  156. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias_expanded/model.onnx +0 -0
  157. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_expanded/model.onnx +0 -0
  158. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap/model.onnx +0 -0
  159. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap_expanded/model.onnx +0 -0
  160. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax/model.onnx +0 -0
  161. onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax_expanded/model.onnx +0 -0
  162. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/model.onnx +0 -0
  163. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_1.pb +1 -1
  164. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_2.pb +1 -1
  165. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/output_0.pb +0 -0
  166. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/model.onnx +0 -0
  167. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_1.pb +1 -1
  168. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_2.pb +1 -1
  169. onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  170. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/model.onnx +0 -0
  171. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_1.pb +0 -0
  172. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_2.pb +0 -0
  173. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/output_0.pb +0 -0
  174. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/model.onnx +0 -0
  175. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
  176. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
  177. onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  178. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/model.onnx +0 -0
  179. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_1.pb +0 -0
  180. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_2.pb +0 -0
  181. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/output_0.pb +0 -0
  182. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx +0 -0
  183. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
  184. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
  185. onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
  186. onnx/defs/nn/defs.cc +70 -228
  187. onnx/defs/nn/old.cc +31 -201
  188. onnx/defs/nn/utils.cc +222 -0
  189. onnx/defs/nn/utils.h +25 -0
  190. onnx/onnx_cpp2py_export.cp310-win_amd64.pyd +0 -0
  191. onnx/reference/ops/op_attention.py +33 -14
  192. onnx/reference/ops/op_rotary_embedding.py +21 -19
  193. onnx/test/basic_test.py +84 -0
  194. onnx/test/reference_evaluator_test.py +23 -0
  195. onnx/test/test_backend_reference.py +2 -1
  196. onnx/version.py +2 -2
  197. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/METADATA +2 -2
  198. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/RECORD +202 -200
  199. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/WHEEL +1 -1
  200. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/entry_points.txt +0 -0
  201. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/licenses/LICENSE +0 -0
  202. {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/top_level.txt +0 -0
onnx/defs/nn/defs.cc CHANGED
@@ -8,6 +8,7 @@
8
8
 
9
9
  #include "onnx/common/assertions.h"
10
10
  #include "onnx/defs/function.h"
11
+ #include "onnx/defs/nn/utils.h"
11
12
  #include "onnx/defs/schema.h"
12
13
 
13
14
  namespace ONNX_NAMESPACE {
@@ -2992,15 +2993,16 @@ The rotation ensures that the model captures both absolute and relative position
2992
2993
  Rotary embeddings are defined using the following algorithm:
2993
2994
 
2994
2995
  ```python
2995
- def compute_rotary_embedding(
2996
- input,
2997
- position_ids,
2998
- sin_cache,
2999
- cos_cache,
3000
- interleaved=0,
3001
- rotary_embedding_dim=0,
3002
- num_heads=0,
3003
- ):
2996
+ def rotary_embedding(
2997
+ input: np.ndarray,
2998
+ cos_cache: np.ndarray,
2999
+ sin_cache: np.ndarray,
3000
+ position_ids: np.ndarray | None = None,
3001
+ interleaved=None,
3002
+ rotary_embedding_dim=None,
3003
+ num_heads=None,
3004
+ ) -> np.ndarray:
3005
+ original_input_shape = input.shape
3004
3006
  # First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
3005
3007
  if len(input.shape) == 4:
3006
3008
  input = np.transpose(input, (0, 2, 1, 3))
@@ -3016,7 +3018,7 @@ def compute_rotary_embedding(
3016
3018
  head_size = input.shape[3]
3017
3019
 
3018
3020
  # Fully or partially perform rotation on input based on rotary_embedding_dim attribute
3019
- if rotary_embedding_dim == 0:
3021
+ if rotary_embedding_dim is None or rotary_embedding_dim == 0:
3020
3022
  # If rotary_embedding_dim not provided, perform full rotation by using head_size
3021
3023
  rotary_embedding_dim = head_size
3022
3024
  x_rotate = input[:, :, :, :rotary_embedding_dim]
@@ -3025,15 +3027,29 @@ def compute_rotary_embedding(
3025
3027
 
3026
3028
  # Retrieve sin and cos caches using position ids
3027
3029
  if position_ids is not None:
3028
- cos = cos_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2]
3029
- sin = sin_cache[position_ids] # Shape: [batch_size, sequence_length, head_size/2]
3030
- else:
3031
- cos = cos_cache
3032
- sin = sin_cache
3033
- cos = cos[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
3034
- sin = sin[:, :, :rotary_embedding_dim_half] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
3035
- cos = np.expand_dims(cos, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
3036
- sin = np.expand_dims(sin, axis=2) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
3030
+ cos_cache = cos_cache[
3031
+ position_ids
3032
+ ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
3033
+ sin_cache = sin_cache[
3034
+ position_ids
3035
+ ] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
3036
+
3037
+ # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
3038
+ if cos_cache.shape[-1] != rotary_embedding_dim_half:
3039
+ raise ValueError(
3040
+ f"Last dimension of cos cache ({cos_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
3041
+ )
3042
+ if sin_cache.shape[-1] != rotary_embedding_dim_half:
3043
+ raise ValueError(
3044
+ f"Last dimension of sin cache ({sin_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
3045
+ )
3046
+
3047
+ cos_cache = np.expand_dims(
3048
+ cos_cache, axis=2
3049
+ ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
3050
+ sin_cache = np.expand_dims(
3051
+ sin_cache, axis=2
3052
+ ) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
3037
3053
 
3038
3054
  # Either divide the input in halves or interleave (based on interleaved attribute)
3039
3055
  if interleaved:
@@ -3043,8 +3059,8 @@ def compute_rotary_embedding(
3043
3059
  x1, x2 = np.split(x_rotate, 2, axis=-1)
3044
3060
 
3045
3061
  # Calculate real and imaginary values
3046
- real = cos * x1 - sin * x2
3047
- imag = sin * x1 + cos * x2
3062
+ real = (cos_cache * x1) - (sin_cache * x2)
3063
+ imag = (sin_cache * x1) + (cos_cache * x2)
3048
3064
 
3049
3065
  # Inserted rotated embeddings back to the original input
3050
3066
  if interleaved:
@@ -3058,7 +3074,7 @@ def compute_rotary_embedding(
3058
3074
  x_rotate = np.concatenate((real, imag), axis=-1)
3059
3075
  output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
3060
3076
  if len(original_input_shape) == 3:
3061
- output = np.reshape(output, input.shape)
3077
+ output = np.reshape(output, original_input_shape)
3062
3078
  else:
3063
3079
  output = np.transpose(output, (0, 2, 1, 3))
3064
3080
  return output
@@ -3505,154 +3521,7 @@ ONNX_OPERATOR_SET_SCHEMA(
3505
3521
  "U",
3506
3522
  OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
3507
3523
  "Constrain output 'mask' types to boolean tensors and input types.")
3508
- .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
3509
- propagateElemTypeFromInputToOutput(ctx, 0, 0);
3510
-
3511
- int64_t kv_sequence_length = -1;
3512
- ONNX_NAMESPACE::TensorShapeProto output_shape;
3513
- ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
3514
- if (hasInputShape(ctx, 0)) {
3515
- auto& query_shape = getInputShape(ctx, 0);
3516
- auto& query_dims = query_shape.dim();
3517
- if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
3518
- fail_shape_inference("Inputs 0 (query) shall be 3 or 4 dimensions");
3519
- }
3520
-
3521
- if (query_dims.size() == 3) {
3522
- auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
3523
- if (q_num_heads_attr == nullptr) {
3524
- fail_type_inference("3D inputs expected to have q_num_heads attribute.");
3525
- }
3526
- auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
3527
- if (kv_num_heads_attr == nullptr) {
3528
- fail_type_inference("3D inputs expected to have q_num_heads attribute.");
3529
- }
3530
- }
3531
-
3532
- *output_shape.add_dim() = query_dims[0]; // batch_size
3533
- *output_shape.add_dim() = query_dims[1]; // num_heads for 4D, sequence_length for 3D
3534
-
3535
- *qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
3536
-
3537
- if (hasInputShape(ctx, 1)) {
3538
- auto& key_shape = getInputShape(ctx, 1);
3539
- auto& key_dims = key_shape.dim();
3540
- if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
3541
- fail_shape_inference("Inputs 1 (key) shall be 3 or 4 dimensions");
3542
- }
3543
- }
3544
-
3545
- if (hasInputShape(ctx, 2)) {
3546
- auto& value_shape = getInputShape(ctx, 2);
3547
- auto& value_dims = value_shape.dim();
3548
- if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
3549
- fail_shape_inference("Inputs 2 (value) shall be 3 or 4 dimensions");
3550
- }
3551
-
3552
- // Update Output Shape for 4D inputs
3553
- // Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
3554
- // Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
3555
- // Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
3556
- // Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
3557
- if (value_dims.size() == 4 && query_dims.size() == 4) {
3558
- kv_sequence_length = value_dims[2].dim_value();
3559
- *output_shape.add_dim() = query_dims[2]; // sequence_length
3560
- *output_shape.add_dim() = value_dims[3]; // head_size
3561
- updateOutputShape(ctx, 0, output_shape);
3562
- // Update qk_matmul_shape
3563
- *qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
3564
- *qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
3565
- qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
3566
- }
3567
-
3568
- // Update Output Shape for 3D inputs
3569
- // Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
3570
- // q_hidden_size = q_num_heads * head_size
3571
- // Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
3572
- // k_hidden_size = kv_num_heads * head_size
3573
- // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
3574
- // v_hidden_size = kv_num_heads * v_head_size
3575
- // Output 0 has shape (batch_size, q_sequence_length, hidden_size),
3576
- // hidden_size = q_num_heads * v_head_size
3577
- if (value_dims.size() == 3 && query_dims.size() == 3) {
3578
- kv_sequence_length = value_dims[1].dim_value();
3579
- auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
3580
- if (q_num_heads_attr == nullptr) {
3581
- fail_type_inference("3D inputs expected to have q_num_heads attribute.");
3582
- }
3583
- auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
3584
- if (kv_num_heads_attr == nullptr) {
3585
- fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
3586
- }
3587
- int64_t q_num_heads = q_num_heads_attr->i();
3588
- int64_t kv_num_heads = kv_num_heads_attr->i();
3589
- // Calculate v_head_size
3590
- int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
3591
- output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
3592
- updateOutputShape(ctx, 0, output_shape);
3593
- // Update qk_matmul_shape
3594
- qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
3595
- *qk_matmul_shape.add_dim() = query_dims[1];
3596
- qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
3597
- }
3598
- }
3599
- }
3600
-
3601
- if (ctx.hasOutput(3)) { // has qk_matmul_output
3602
- propagateElemTypeFromInputToOutput(ctx, 0, 3);
3603
- updateOutputShape(ctx, 3, qk_matmul_shape);
3604
- }
3605
-
3606
- if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
3607
- if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key
3608
- // copy the type from query to present key and value
3609
- propagateElemTypeFromInputToOutput(ctx, 4, 1);
3610
- propagateElemTypeFromInputToOutput(ctx, 5, 2);
3611
-
3612
- if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
3613
- auto& past_key_shape = getInputShape(ctx, 4);
3614
- auto& past_key_dims = past_key_shape.dim();
3615
- auto& past_value_shape = getInputShape(ctx, 5);
3616
- auto& past_value_dims = past_value_shape.dim();
3617
-
3618
- // past key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
3619
- if (past_key_dims.size() != 4) {
3620
- fail_shape_inference("The past_key input shall be 4 dimensions");
3621
- }
3622
- // past value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
3623
- if (past_value_dims.size() != 4) {
3624
- fail_shape_inference("The past_value input shall be 4 dimensions");
3625
- }
3626
-
3627
- if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
3628
- int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
3629
-
3630
- ONNX_NAMESPACE::TensorShapeProto present_key_shape;
3631
- for (auto& dim : past_key_dims) {
3632
- *present_key_shape.add_dim() = dim;
3633
- }
3634
-
3635
- ONNX_NAMESPACE::TensorShapeProto present_value_shape;
3636
- for (auto& dim : past_value_dims) {
3637
- *present_value_shape.add_dim() = dim;
3638
- }
3639
-
3640
- if (ctx.hasOutput(3)) { // has qk_matmul_output with bias
3641
- qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
3642
- updateOutputShape(ctx, 3, qk_matmul_shape);
3643
- }
3644
-
3645
- // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
3646
- present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
3647
- present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
3648
-
3649
- updateOutputShape(ctx, 1, present_key_shape);
3650
- updateOutputShape(ctx, 2, present_value_shape);
3651
- }
3652
- }
3653
- }
3654
- }
3655
- })
3524
+ .TypeAndShapeInferenceFunction(defs::nn::utils::AttentionPropagateElemTypeFromInputToOutput)
3656
3525
  .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx,
3657
3526
  const OpSchema& schema,
3658
3527
  FunctionProto& functionProto) {
@@ -3676,11 +3545,6 @@ ONNX_OPERATOR_SET_SCHEMA(
3676
3545
  (softmax_precision != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))
3677
3546
  return false; // Error
3678
3547
 
3679
- auto mkbooltensor = [](bool val) -> ONNX_NAMESPACE::TensorProto {
3680
- auto tp = ONNX_NAMESPACE::ToTensor(std::vector<bool>{val});
3681
- tp.add_dims(1);
3682
- return tp;
3683
- };
3684
3548
  // If shape is 3D, q_num_heads and kv_num_heads is provided,
3685
3549
  // for 4D cases, set num_heads to zero for reshape purposes
3686
3550
  auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
@@ -3692,15 +3556,17 @@ ONNX_OPERATOR_SET_SCHEMA(
3692
3556
  bool is_3d_input = (q_num_heads > 0 && kv_num_heads > 0);
3693
3557
 
3694
3558
  FunctionBuilder builder(functionProto);
3559
+ builder
3560
+ .Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
3561
+ .Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
3562
+ .Add("KVSeqLen = Shape <start = -2, end = -1> (K)"); // kv_sequence_length
3563
+
3695
3564
  if (is_3d_input) {
3696
3565
  // For 3D inputs: First reshape to [batch_size, seq_length, num_heads, head_size]
3697
3566
  // then transpose to [batch_size, num_heads, seq_length, head_size]
3698
3567
  builder
3699
- .Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
3700
3568
  .Const1D("QNumHeadsAttr", q_num_heads) // q_num_heads from attrs
3701
3569
  .Const1D("KVNumHeadsAttr", kv_num_heads) // kv_num_heads from attrs
3702
- .Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
3703
- .Add("KVSeqLen = Shape <start = -2, end = -1> (K)") // kv_sequence_length
3704
3570
  .Const1D("NegOne", static_cast<int64_t>(-1)); // head_size, inferred from other dimensions
3705
3571
 
3706
3572
  builder.Add("QIntermediateShape = Concat <axis = 0> (BatchSize, QSeqLen, QNumHeadsAttr, NegOne)")
@@ -3715,7 +3581,6 @@ ONNX_OPERATOR_SET_SCHEMA(
3715
3581
  } else {
3716
3582
  // For 4D inputs: Already in desired shape [batch_size, num_heads, seq_length, head_size]
3717
3583
  builder.Add("QReshaped = Identity(Q)").Add("KReshaped = Identity(K)").Add("VReshaped = Identity(V)");
3718
- builder.Add("QSeqLen = Shape <start = -2, end = -1> (Q)");
3719
3584
  }
3720
3585
 
3721
3586
  builder
@@ -3728,6 +3593,7 @@ ONNX_OPERATOR_SET_SCHEMA(
3728
3593
  builder
3729
3594
  .Add("QKHeadSize = Shape <start = 3, end = 4> (QReshaped)") // head_size for Q and K
3730
3595
  .Add("QKHeadSizeF = Cast (QKHeadSize)", "to", float_type)
3596
+ .Add("VHeadSize = Shape <start = 3, end = 4> (VReshaped)") // head_size for V
3731
3597
  .Add("SqrtHeadSize = Sqrt(QKHeadSizeF)")
3732
3598
  .Const1D("One1D", static_cast<int64_t>(1))
3733
3599
  .Const1D("NegOne1D", static_cast<int64_t>(-1))
@@ -3743,8 +3609,10 @@ ONNX_OPERATOR_SET_SCHEMA(
3743
3609
 
3744
3610
  if (ctx.hasInput(4)) {
3745
3611
  builder.Add("PresentKey = Concat <axis = 2> (past_key, KReshaped)");
3612
+ builder.Add("PastKVSeqLen = Shape <start = -2, end = -1> (past_key)");
3746
3613
  } else {
3747
3614
  builder.Add("PresentKey = Identity (KReshaped)");
3615
+ builder.Const1D("PastKVSeqLen", static_cast<int64_t>(0));
3748
3616
  }
3749
3617
  if (ctx.hasOutput(1)) {
3750
3618
  builder.Add("present_key = Identity (PresentKey)");
@@ -3759,52 +3627,11 @@ ONNX_OPERATOR_SET_SCHEMA(
3759
3627
  builder.Add("present_value = Identity (PresentValue)");
3760
3628
  }
3761
3629
 
3762
- builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)");
3763
- builder.Add("AttnBiasShape = Concat <axis = 0> (QSeqLen, NewKVSeqLen)");
3764
- float neg_inf = -std::numeric_limits<float>::infinity();
3765
- builder.Const1D("FloatNegInf", neg_inf);
3766
- builder.Const1D("ScalarZero", 0.f);
3767
-
3768
- // If attn_mask is provided
3769
- if (ctx.hasInput(3)) {
3770
- auto* up = ctx.getInputType(3);
3771
- if ((up == nullptr) || (!up->has_tensor_type()))
3772
- return false;
3773
- int64_t U = up->tensor_type().elem_type();
3774
- builder.Add(
3775
- U == ONNX_NAMESPACE::TensorProto_DataType_BOOL
3776
- ? "AttnBiasShort = Where(attn_mask, ScalarZero, FloatNegInf)"
3777
- : "AttnBiasShort = Identity(attn_mask)");
3778
- // If attn_mask has a shorter kv sequence length, we pad it to NewKVSeqLen with FloatNegInf
3779
- builder.Add("MaskKVSeqLen = Shape <start = -1> (attn_mask)")
3780
- .Add("PaddingKVSeqLen = Sub(NewKVSeqLen, MaskKVSeqLen)")
3781
- .Add("Pads = Concat <axis = 0> (Zero1D, PaddingKVSeqLen)")
3782
- .Add("FloatNegInfCast = CastLike(FloatNegInf, AttnBiasShort)")
3783
- .Add("AttnBias = Pad(AttnBiasShort, Pads, FloatNegInfCast, NegOne1D)");
3784
- } else {
3785
- builder.Add("AttnBias = ConstantOfShape(AttnBiasShape)");
3786
- }
3787
-
3788
- // If is_causal set to true, the attention masking is a lower triangular matrix when the mask
3789
- // is a square matrix. The attention masking has the form of the upper left causal bias due to
3790
- // the alignment when the mask is a non-square matrix.
3791
- // An error is thrown if both attn_mask and is_causal are set.
3792
- auto* is_causal_attr = ctx.getAttribute("is_causal");
3793
- int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
3794
- if (is_causal == 1) {
3795
- builder.Add("BoolMask = ConstantOfShape(AttnBiasShape)", "value", mkbooltensor(1))
3796
- .Add("BoolMaskTri = Trilu <upper = 0> (BoolMask, Zero1D)")
3797
- .Add("MaskTri = Where(BoolMaskTri, ScalarZero, FloatNegInf)")
3798
- .Add("AttnBiasCausal = Add(AttnBias, MaskTri)");
3799
- } else {
3800
- builder.Add("AttnBiasCausal = Identity(AttnBias)");
3801
- }
3630
+ if (!defs::nn::utils::AttentionAppendFunctionCausalMask(ctx, builder, true))
3631
+ return false;
3802
3632
 
3803
3633
  // Add padding mask if kv_nonpad_seqlen is provided
3804
3634
  if (ctx.hasInput(6)) {
3805
- if (!is_3d_input) {
3806
- builder.Add("KVSeqLen = Shape <start = -2, end = -1> (K)");
3807
- }
3808
3635
  builder
3809
3636
  .Add("KVSeqLenExpanded = Unsqueeze(nonpad_kv_seqlen, One1D)") // [batch_size, 1]
3810
3637
  .Add("KVSeqLen0D = Squeeze(KVSeqLen)")
@@ -3815,9 +3642,9 @@ ONNX_OPERATOR_SET_SCHEMA(
3815
3642
  .Add("PaddingMaskFloat = Where(PaddingMaskBool, ScalarZero, FloatNegInf)") // [batch_size, KVSeqLen]
3816
3643
  .Add("PaddingMask3D = Unsqueeze(PaddingMaskFloat, One1D)") // [batch_size, 1, KVSeqLen]
3817
3644
  .Add("PaddingMask4D = Unsqueeze(PaddingMask3D, One1D)") // [batch_size, 1, 1, KVSeqLen]
3818
- .Add("AttnBiasCausalPad = Add(AttnBiasCausal, PaddingMask4D)");
3645
+ .Add("AttnBiasCausalPad = Add(AttnBiasCausalOrNot, PaddingMask4D)");
3819
3646
  } else {
3820
- builder.Add("AttnBiasCausalPad = Identity(AttnBiasCausal)");
3647
+ builder.Add("AttnBiasCausalPad = Identity(AttnBiasCausalOrNot)");
3821
3648
  }
3822
3649
  builder.Add("AttnBiasT = Cast (AttnBiasCausalPad)", "to", T1);
3823
3650
 
@@ -3832,10 +3659,25 @@ ONNX_OPERATOR_SET_SCHEMA(
3832
3659
  .Add("RemainderNumHeads = Mod(QNumHeads, KVNumHeads)")
3833
3660
  .Add("GQACond2 = Equal(RemainderNumHeads, Zero1D)")
3834
3661
  .Add("GQACond = And(GQACond1, GQACond2)")
3835
- .Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)")
3836
- .Add("InterleaveShape = Concat <axis = 0> (One1D, InterleaveDim, One1D, One1D)")
3837
- .Add("KAttentionInput = Tile(PresentKey, InterleaveShape)")
3838
- .Add("VAttentionInput = Tile(PresentValue, InterleaveShape)");
3662
+ .Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)");
3663
+
3664
+ // repeat kv (repeat_interleave)
3665
+ builder.Const1D("Two1D", static_cast<int64_t>(2))
3666
+ .Add("KUnsqueezed = Unsqueeze(PresentKey, Two1D)") // [B, Hk, 1, T, Dk]
3667
+ .Add("VUnsqueezed = Unsqueeze(PresentValue, Two1D)"); // [B, Hk, 1, T, Dv]
3668
+
3669
+ // Build expand shape: [B, Hk, repeats, T, Dk]
3670
+ builder
3671
+ .Add("KExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, QKHeadSize)")
3672
+ .Add("KExpanded = Expand(KUnsqueezed, KExpandShape)");
3673
+ builder.Add("VExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, VHeadSize)")
3674
+ .Add("VExpanded = Expand(VUnsqueezed, VExpandShape)");
3675
+
3676
+ // Reshape to [B, Hq, T, Dk] where Hq = Hk * repeats
3677
+ builder.Add("KAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, QKHeadSize)")
3678
+ .Add("VAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, VHeadSize)")
3679
+ .Add("KAttentionInput = Reshape(KExpanded, KAttentionShape)")
3680
+ .Add("VAttentionInput = Reshape(VExpanded, VAttentionShape)");
3839
3681
 
3840
3682
  // The following pattern is applied
3841
3683
  // Q K V