onnx 1.19.0__cp312-cp312-macosx_12_0_universal2.whl → 1.19.1rc1__cp312-cp312-macosx_12_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx might be problematic. Click here for more details.
- onnx/__init__.py +98 -0
- onnx/backend/test/case/node/__init__.py +20 -3
- onnx/backend/test/case/node/attention.py +62 -0
- onnx/backend/test/case/node/rotaryembedding.py +6 -6
- onnx/backend/test/data/node/test_attention_3d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_transpose_verification/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_transpose_verification_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_mask4d_padded_kv_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_fp16/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_fp16_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_1.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_2.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_1.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_2.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/defs/nn/defs.cc +70 -228
- onnx/defs/nn/old.cc +31 -201
- onnx/defs/nn/utils.cc +222 -0
- onnx/defs/nn/utils.h +25 -0
- onnx/onnx_cpp2py_export.cpython-312-darwin.so +0 -0
- onnx/reference/ops/op_attention.py +33 -14
- onnx/reference/ops/op_rotary_embedding.py +21 -19
- onnx/test/basic_test.py +84 -0
- onnx/test/reference_evaluator_test.py +23 -0
- onnx/test/test_backend_reference.py +2 -1
- onnx/version.py +2 -2
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/METADATA +2 -2
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/RECORD +202 -200
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/WHEEL +1 -1
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/entry_points.txt +0 -0
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/licenses/LICENSE +0 -0
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/top_level.txt +0 -0
onnx/defs/nn/defs.cc
CHANGED
|
@@ -8,6 +8,7 @@
|
|
|
8
8
|
|
|
9
9
|
#include "onnx/common/assertions.h"
|
|
10
10
|
#include "onnx/defs/function.h"
|
|
11
|
+
#include "onnx/defs/nn/utils.h"
|
|
11
12
|
#include "onnx/defs/schema.h"
|
|
12
13
|
|
|
13
14
|
namespace ONNX_NAMESPACE {
|
|
@@ -2992,15 +2993,16 @@ The rotation ensures that the model captures both absolute and relative position
|
|
|
2992
2993
|
Rotary embeddings are defined using the following algorithm:
|
|
2993
2994
|
|
|
2994
2995
|
```python
|
|
2995
|
-
def
|
|
2996
|
-
input,
|
|
2997
|
-
|
|
2998
|
-
sin_cache,
|
|
2999
|
-
|
|
3000
|
-
interleaved=
|
|
3001
|
-
rotary_embedding_dim=
|
|
3002
|
-
num_heads=
|
|
3003
|
-
):
|
|
2996
|
+
def rotary_embedding(
|
|
2997
|
+
input: np.ndarray,
|
|
2998
|
+
cos_cache: np.ndarray,
|
|
2999
|
+
sin_cache: np.ndarray,
|
|
3000
|
+
position_ids: np.ndarray | None = None,
|
|
3001
|
+
interleaved=None,
|
|
3002
|
+
rotary_embedding_dim=None,
|
|
3003
|
+
num_heads=None,
|
|
3004
|
+
) -> np.ndarray:
|
|
3005
|
+
original_input_shape = input.shape
|
|
3004
3006
|
# First ensure input to be processed has shape [batch_size, seq_len, num_heads, head_size]
|
|
3005
3007
|
if len(input.shape) == 4:
|
|
3006
3008
|
input = np.transpose(input, (0, 2, 1, 3))
|
|
@@ -3016,7 +3018,7 @@ def compute_rotary_embedding(
|
|
|
3016
3018
|
head_size = input.shape[3]
|
|
3017
3019
|
|
|
3018
3020
|
# Fully or partially perform rotation on input based on rotary_embedding_dim attribute
|
|
3019
|
-
if rotary_embedding_dim == 0:
|
|
3021
|
+
if rotary_embedding_dim is None or rotary_embedding_dim == 0:
|
|
3020
3022
|
# If rotary_embedding_dim not provided, perform full rotation by using head_size
|
|
3021
3023
|
rotary_embedding_dim = head_size
|
|
3022
3024
|
x_rotate = input[:, :, :, :rotary_embedding_dim]
|
|
@@ -3025,15 +3027,29 @@ def compute_rotary_embedding(
|
|
|
3025
3027
|
|
|
3026
3028
|
# Retrieve sin and cos caches using position ids
|
|
3027
3029
|
if position_ids is not None:
|
|
3028
|
-
|
|
3029
|
-
|
|
3030
|
-
|
|
3031
|
-
|
|
3032
|
-
|
|
3033
|
-
|
|
3034
|
-
|
|
3035
|
-
|
|
3036
|
-
|
|
3030
|
+
cos_cache = cos_cache[
|
|
3031
|
+
position_ids
|
|
3032
|
+
] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
|
|
3033
|
+
sin_cache = sin_cache[
|
|
3034
|
+
position_ids
|
|
3035
|
+
] # Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
|
|
3036
|
+
|
|
3037
|
+
# Shape: [batch_size, sequence_length, rotary_embedding_dim/2]
|
|
3038
|
+
if cos_cache.shape[-1] != rotary_embedding_dim_half:
|
|
3039
|
+
raise ValueError(
|
|
3040
|
+
f"Last dimension of cos cache ({cos_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
|
|
3041
|
+
)
|
|
3042
|
+
if sin_cache.shape[-1] != rotary_embedding_dim_half:
|
|
3043
|
+
raise ValueError(
|
|
3044
|
+
f"Last dimension of sin cache ({sin_cache.shape[-1]}) does not match rotary_embedding_dim/2 ({rotary_embedding_dim_half})."
|
|
3045
|
+
)
|
|
3046
|
+
|
|
3047
|
+
cos_cache = np.expand_dims(
|
|
3048
|
+
cos_cache, axis=2
|
|
3049
|
+
) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
|
|
3050
|
+
sin_cache = np.expand_dims(
|
|
3051
|
+
sin_cache, axis=2
|
|
3052
|
+
) # Shape: [batch_size, sequence_length, 1, rotary_embedding_dim/2]
|
|
3037
3053
|
|
|
3038
3054
|
# Either divide the input in halves or interleave (based on interleaved attribute)
|
|
3039
3055
|
if interleaved:
|
|
@@ -3043,8 +3059,8 @@ def compute_rotary_embedding(
|
|
|
3043
3059
|
x1, x2 = np.split(x_rotate, 2, axis=-1)
|
|
3044
3060
|
|
|
3045
3061
|
# Calculate real and imaginary values
|
|
3046
|
-
real =
|
|
3047
|
-
imag =
|
|
3062
|
+
real = (cos_cache * x1) - (sin_cache * x2)
|
|
3063
|
+
imag = (sin_cache * x1) + (cos_cache * x2)
|
|
3048
3064
|
|
|
3049
3065
|
# Inserted rotated embeddings back to the original input
|
|
3050
3066
|
if interleaved:
|
|
@@ -3058,7 +3074,7 @@ def compute_rotary_embedding(
|
|
|
3058
3074
|
x_rotate = np.concatenate((real, imag), axis=-1)
|
|
3059
3075
|
output = np.concatenate((x_rotate, x_not_rotate), axis=-1)
|
|
3060
3076
|
if len(original_input_shape) == 3:
|
|
3061
|
-
output = np.reshape(output,
|
|
3077
|
+
output = np.reshape(output, original_input_shape)
|
|
3062
3078
|
else:
|
|
3063
3079
|
output = np.transpose(output, (0, 2, 1, 3))
|
|
3064
3080
|
return output
|
|
@@ -3505,154 +3521,7 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3505
3521
|
"U",
|
|
3506
3522
|
OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
|
|
3507
3523
|
"Constrain output 'mask' types to boolean tensors and input types.")
|
|
3508
|
-
.TypeAndShapeInferenceFunction(
|
|
3509
|
-
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
|
3510
|
-
|
|
3511
|
-
int64_t kv_sequence_length = -1;
|
|
3512
|
-
ONNX_NAMESPACE::TensorShapeProto output_shape;
|
|
3513
|
-
ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
|
|
3514
|
-
if (hasInputShape(ctx, 0)) {
|
|
3515
|
-
auto& query_shape = getInputShape(ctx, 0);
|
|
3516
|
-
auto& query_dims = query_shape.dim();
|
|
3517
|
-
if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
|
|
3518
|
-
fail_shape_inference("Inputs 0 (query) shall be 3 or 4 dimensions");
|
|
3519
|
-
}
|
|
3520
|
-
|
|
3521
|
-
if (query_dims.size() == 3) {
|
|
3522
|
-
auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
|
|
3523
|
-
if (q_num_heads_attr == nullptr) {
|
|
3524
|
-
fail_type_inference("3D inputs expected to have q_num_heads attribute.");
|
|
3525
|
-
}
|
|
3526
|
-
auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
|
|
3527
|
-
if (kv_num_heads_attr == nullptr) {
|
|
3528
|
-
fail_type_inference("3D inputs expected to have q_num_heads attribute.");
|
|
3529
|
-
}
|
|
3530
|
-
}
|
|
3531
|
-
|
|
3532
|
-
*output_shape.add_dim() = query_dims[0]; // batch_size
|
|
3533
|
-
*output_shape.add_dim() = query_dims[1]; // num_heads for 4D, sequence_length for 3D
|
|
3534
|
-
|
|
3535
|
-
*qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
|
|
3536
|
-
|
|
3537
|
-
if (hasInputShape(ctx, 1)) {
|
|
3538
|
-
auto& key_shape = getInputShape(ctx, 1);
|
|
3539
|
-
auto& key_dims = key_shape.dim();
|
|
3540
|
-
if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
|
|
3541
|
-
fail_shape_inference("Inputs 1 (key) shall be 3 or 4 dimensions");
|
|
3542
|
-
}
|
|
3543
|
-
}
|
|
3544
|
-
|
|
3545
|
-
if (hasInputShape(ctx, 2)) {
|
|
3546
|
-
auto& value_shape = getInputShape(ctx, 2);
|
|
3547
|
-
auto& value_dims = value_shape.dim();
|
|
3548
|
-
if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
|
|
3549
|
-
fail_shape_inference("Inputs 2 (value) shall be 3 or 4 dimensions");
|
|
3550
|
-
}
|
|
3551
|
-
|
|
3552
|
-
// Update Output Shape for 4D inputs
|
|
3553
|
-
// Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
|
|
3554
|
-
// Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
|
|
3555
|
-
// Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
|
|
3556
|
-
// Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
|
|
3557
|
-
if (value_dims.size() == 4 && query_dims.size() == 4) {
|
|
3558
|
-
kv_sequence_length = value_dims[2].dim_value();
|
|
3559
|
-
*output_shape.add_dim() = query_dims[2]; // sequence_length
|
|
3560
|
-
*output_shape.add_dim() = value_dims[3]; // head_size
|
|
3561
|
-
updateOutputShape(ctx, 0, output_shape);
|
|
3562
|
-
// Update qk_matmul_shape
|
|
3563
|
-
*qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
|
|
3564
|
-
*qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
|
|
3565
|
-
qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
|
|
3566
|
-
}
|
|
3567
|
-
|
|
3568
|
-
// Update Output Shape for 3D inputs
|
|
3569
|
-
// Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
|
|
3570
|
-
// q_hidden_size = q_num_heads * head_size
|
|
3571
|
-
// Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
|
|
3572
|
-
// k_hidden_size = kv_num_heads * head_size
|
|
3573
|
-
// Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
|
|
3574
|
-
// v_hidden_size = kv_num_heads * v_head_size
|
|
3575
|
-
// Output 0 has shape (batch_size, q_sequence_length, hidden_size),
|
|
3576
|
-
// hidden_size = q_num_heads * v_head_size
|
|
3577
|
-
if (value_dims.size() == 3 && query_dims.size() == 3) {
|
|
3578
|
-
kv_sequence_length = value_dims[1].dim_value();
|
|
3579
|
-
auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
|
|
3580
|
-
if (q_num_heads_attr == nullptr) {
|
|
3581
|
-
fail_type_inference("3D inputs expected to have q_num_heads attribute.");
|
|
3582
|
-
}
|
|
3583
|
-
auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
|
|
3584
|
-
if (kv_num_heads_attr == nullptr) {
|
|
3585
|
-
fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
|
|
3586
|
-
}
|
|
3587
|
-
int64_t q_num_heads = q_num_heads_attr->i();
|
|
3588
|
-
int64_t kv_num_heads = kv_num_heads_attr->i();
|
|
3589
|
-
// Calculate v_head_size
|
|
3590
|
-
int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
|
|
3591
|
-
output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
|
|
3592
|
-
updateOutputShape(ctx, 0, output_shape);
|
|
3593
|
-
// Update qk_matmul_shape
|
|
3594
|
-
qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
|
|
3595
|
-
*qk_matmul_shape.add_dim() = query_dims[1];
|
|
3596
|
-
qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
|
|
3597
|
-
}
|
|
3598
|
-
}
|
|
3599
|
-
}
|
|
3600
|
-
|
|
3601
|
-
if (ctx.hasOutput(3)) { // has qk_matmul_output
|
|
3602
|
-
propagateElemTypeFromInputToOutput(ctx, 0, 3);
|
|
3603
|
-
updateOutputShape(ctx, 3, qk_matmul_shape);
|
|
3604
|
-
}
|
|
3605
|
-
|
|
3606
|
-
if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
|
|
3607
|
-
if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key
|
|
3608
|
-
// copy the type from query to present key and value
|
|
3609
|
-
propagateElemTypeFromInputToOutput(ctx, 4, 1);
|
|
3610
|
-
propagateElemTypeFromInputToOutput(ctx, 5, 2);
|
|
3611
|
-
|
|
3612
|
-
if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
|
|
3613
|
-
auto& past_key_shape = getInputShape(ctx, 4);
|
|
3614
|
-
auto& past_key_dims = past_key_shape.dim();
|
|
3615
|
-
auto& past_value_shape = getInputShape(ctx, 5);
|
|
3616
|
-
auto& past_value_dims = past_value_shape.dim();
|
|
3617
|
-
|
|
3618
|
-
// past key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
|
|
3619
|
-
if (past_key_dims.size() != 4) {
|
|
3620
|
-
fail_shape_inference("The past_key input shall be 4 dimensions");
|
|
3621
|
-
}
|
|
3622
|
-
// past value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
|
|
3623
|
-
if (past_value_dims.size() != 4) {
|
|
3624
|
-
fail_shape_inference("The past_value input shall be 4 dimensions");
|
|
3625
|
-
}
|
|
3626
|
-
|
|
3627
|
-
if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
|
|
3628
|
-
int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
|
|
3629
|
-
|
|
3630
|
-
ONNX_NAMESPACE::TensorShapeProto present_key_shape;
|
|
3631
|
-
for (auto& dim : past_key_dims) {
|
|
3632
|
-
*present_key_shape.add_dim() = dim;
|
|
3633
|
-
}
|
|
3634
|
-
|
|
3635
|
-
ONNX_NAMESPACE::TensorShapeProto present_value_shape;
|
|
3636
|
-
for (auto& dim : past_value_dims) {
|
|
3637
|
-
*present_value_shape.add_dim() = dim;
|
|
3638
|
-
}
|
|
3639
|
-
|
|
3640
|
-
if (ctx.hasOutput(3)) { // has qk_matmul_output with bias
|
|
3641
|
-
qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
|
|
3642
|
-
updateOutputShape(ctx, 3, qk_matmul_shape);
|
|
3643
|
-
}
|
|
3644
|
-
|
|
3645
|
-
// shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
|
|
3646
|
-
present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
|
|
3647
|
-
present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
|
|
3648
|
-
|
|
3649
|
-
updateOutputShape(ctx, 1, present_key_shape);
|
|
3650
|
-
updateOutputShape(ctx, 2, present_value_shape);
|
|
3651
|
-
}
|
|
3652
|
-
}
|
|
3653
|
-
}
|
|
3654
|
-
}
|
|
3655
|
-
})
|
|
3524
|
+
.TypeAndShapeInferenceFunction(defs::nn::utils::AttentionPropagateElemTypeFromInputToOutput)
|
|
3656
3525
|
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx,
|
|
3657
3526
|
const OpSchema& schema,
|
|
3658
3527
|
FunctionProto& functionProto) {
|
|
@@ -3676,11 +3545,6 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3676
3545
|
(softmax_precision != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))
|
|
3677
3546
|
return false; // Error
|
|
3678
3547
|
|
|
3679
|
-
auto mkbooltensor = [](bool val) -> ONNX_NAMESPACE::TensorProto {
|
|
3680
|
-
auto tp = ONNX_NAMESPACE::ToTensor(std::vector<bool>{val});
|
|
3681
|
-
tp.add_dims(1);
|
|
3682
|
-
return tp;
|
|
3683
|
-
};
|
|
3684
3548
|
// If shape is 3D, q_num_heads and kv_num_heads is provided,
|
|
3685
3549
|
// for 4D cases, set num_heads to zero for reshape purposes
|
|
3686
3550
|
auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
|
|
@@ -3692,15 +3556,17 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3692
3556
|
bool is_3d_input = (q_num_heads > 0 && kv_num_heads > 0);
|
|
3693
3557
|
|
|
3694
3558
|
FunctionBuilder builder(functionProto);
|
|
3559
|
+
builder
|
|
3560
|
+
.Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
|
|
3561
|
+
.Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
|
|
3562
|
+
.Add("KVSeqLen = Shape <start = -2, end = -1> (K)"); // kv_sequence_length
|
|
3563
|
+
|
|
3695
3564
|
if (is_3d_input) {
|
|
3696
3565
|
// For 3D inputs: First reshape to [batch_size, seq_length, num_heads, head_size]
|
|
3697
3566
|
// then transpose to [batch_size, num_heads, seq_length, head_size]
|
|
3698
3567
|
builder
|
|
3699
|
-
.Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
|
|
3700
3568
|
.Const1D("QNumHeadsAttr", q_num_heads) // q_num_heads from attrs
|
|
3701
3569
|
.Const1D("KVNumHeadsAttr", kv_num_heads) // kv_num_heads from attrs
|
|
3702
|
-
.Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
|
|
3703
|
-
.Add("KVSeqLen = Shape <start = -2, end = -1> (K)") // kv_sequence_length
|
|
3704
3570
|
.Const1D("NegOne", static_cast<int64_t>(-1)); // head_size, inferred from other dimensions
|
|
3705
3571
|
|
|
3706
3572
|
builder.Add("QIntermediateShape = Concat <axis = 0> (BatchSize, QSeqLen, QNumHeadsAttr, NegOne)")
|
|
@@ -3715,7 +3581,6 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3715
3581
|
} else {
|
|
3716
3582
|
// For 4D inputs: Already in desired shape [batch_size, num_heads, seq_length, head_size]
|
|
3717
3583
|
builder.Add("QReshaped = Identity(Q)").Add("KReshaped = Identity(K)").Add("VReshaped = Identity(V)");
|
|
3718
|
-
builder.Add("QSeqLen = Shape <start = -2, end = -1> (Q)");
|
|
3719
3584
|
}
|
|
3720
3585
|
|
|
3721
3586
|
builder
|
|
@@ -3728,6 +3593,7 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3728
3593
|
builder
|
|
3729
3594
|
.Add("QKHeadSize = Shape <start = 3, end = 4> (QReshaped)") // head_size for Q and K
|
|
3730
3595
|
.Add("QKHeadSizeF = Cast (QKHeadSize)", "to", float_type)
|
|
3596
|
+
.Add("VHeadSize = Shape <start = 3, end = 4> (VReshaped)") // head_size for V
|
|
3731
3597
|
.Add("SqrtHeadSize = Sqrt(QKHeadSizeF)")
|
|
3732
3598
|
.Const1D("One1D", static_cast<int64_t>(1))
|
|
3733
3599
|
.Const1D("NegOne1D", static_cast<int64_t>(-1))
|
|
@@ -3743,8 +3609,10 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3743
3609
|
|
|
3744
3610
|
if (ctx.hasInput(4)) {
|
|
3745
3611
|
builder.Add("PresentKey = Concat <axis = 2> (past_key, KReshaped)");
|
|
3612
|
+
builder.Add("PastKVSeqLen = Shape <start = -2, end = -1> (past_key)");
|
|
3746
3613
|
} else {
|
|
3747
3614
|
builder.Add("PresentKey = Identity (KReshaped)");
|
|
3615
|
+
builder.Const1D("PastKVSeqLen", static_cast<int64_t>(0));
|
|
3748
3616
|
}
|
|
3749
3617
|
if (ctx.hasOutput(1)) {
|
|
3750
3618
|
builder.Add("present_key = Identity (PresentKey)");
|
|
@@ -3759,52 +3627,11 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3759
3627
|
builder.Add("present_value = Identity (PresentValue)");
|
|
3760
3628
|
}
|
|
3761
3629
|
|
|
3762
|
-
|
|
3763
|
-
|
|
3764
|
-
float neg_inf = -std::numeric_limits<float>::infinity();
|
|
3765
|
-
builder.Const1D("FloatNegInf", neg_inf);
|
|
3766
|
-
builder.Const1D("ScalarZero", 0.f);
|
|
3767
|
-
|
|
3768
|
-
// If attn_mask is provided
|
|
3769
|
-
if (ctx.hasInput(3)) {
|
|
3770
|
-
auto* up = ctx.getInputType(3);
|
|
3771
|
-
if ((up == nullptr) || (!up->has_tensor_type()))
|
|
3772
|
-
return false;
|
|
3773
|
-
int64_t U = up->tensor_type().elem_type();
|
|
3774
|
-
builder.Add(
|
|
3775
|
-
U == ONNX_NAMESPACE::TensorProto_DataType_BOOL
|
|
3776
|
-
? "AttnBiasShort = Where(attn_mask, ScalarZero, FloatNegInf)"
|
|
3777
|
-
: "AttnBiasShort = Identity(attn_mask)");
|
|
3778
|
-
// If attn_mask has a shorter kv sequence length, we pad it to NewKVSeqLen with FloatNegInf
|
|
3779
|
-
builder.Add("MaskKVSeqLen = Shape <start = -1> (attn_mask)")
|
|
3780
|
-
.Add("PaddingKVSeqLen = Sub(NewKVSeqLen, MaskKVSeqLen)")
|
|
3781
|
-
.Add("Pads = Concat <axis = 0> (Zero1D, PaddingKVSeqLen)")
|
|
3782
|
-
.Add("FloatNegInfCast = CastLike(FloatNegInf, AttnBiasShort)")
|
|
3783
|
-
.Add("AttnBias = Pad(AttnBiasShort, Pads, FloatNegInfCast, NegOne1D)");
|
|
3784
|
-
} else {
|
|
3785
|
-
builder.Add("AttnBias = ConstantOfShape(AttnBiasShape)");
|
|
3786
|
-
}
|
|
3787
|
-
|
|
3788
|
-
// If is_causal set to true, the attention masking is a lower triangular matrix when the mask
|
|
3789
|
-
// is a square matrix. The attention masking has the form of the upper left causal bias due to
|
|
3790
|
-
// the alignment when the mask is a non-square matrix.
|
|
3791
|
-
// An error is thrown if both attn_mask and is_causal are set.
|
|
3792
|
-
auto* is_causal_attr = ctx.getAttribute("is_causal");
|
|
3793
|
-
int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
|
|
3794
|
-
if (is_causal == 1) {
|
|
3795
|
-
builder.Add("BoolMask = ConstantOfShape(AttnBiasShape)", "value", mkbooltensor(1))
|
|
3796
|
-
.Add("BoolMaskTri = Trilu <upper = 0> (BoolMask, Zero1D)")
|
|
3797
|
-
.Add("MaskTri = Where(BoolMaskTri, ScalarZero, FloatNegInf)")
|
|
3798
|
-
.Add("AttnBiasCausal = Add(AttnBias, MaskTri)");
|
|
3799
|
-
} else {
|
|
3800
|
-
builder.Add("AttnBiasCausal = Identity(AttnBias)");
|
|
3801
|
-
}
|
|
3630
|
+
if (!defs::nn::utils::AttentionAppendFunctionCausalMask(ctx, builder, true))
|
|
3631
|
+
return false;
|
|
3802
3632
|
|
|
3803
3633
|
// Add padding mask if kv_nonpad_seqlen is provided
|
|
3804
3634
|
if (ctx.hasInput(6)) {
|
|
3805
|
-
if (!is_3d_input) {
|
|
3806
|
-
builder.Add("KVSeqLen = Shape <start = -2, end = -1> (K)");
|
|
3807
|
-
}
|
|
3808
3635
|
builder
|
|
3809
3636
|
.Add("KVSeqLenExpanded = Unsqueeze(nonpad_kv_seqlen, One1D)") // [batch_size, 1]
|
|
3810
3637
|
.Add("KVSeqLen0D = Squeeze(KVSeqLen)")
|
|
@@ -3815,9 +3642,9 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3815
3642
|
.Add("PaddingMaskFloat = Where(PaddingMaskBool, ScalarZero, FloatNegInf)") // [batch_size, KVSeqLen]
|
|
3816
3643
|
.Add("PaddingMask3D = Unsqueeze(PaddingMaskFloat, One1D)") // [batch_size, 1, KVSeqLen]
|
|
3817
3644
|
.Add("PaddingMask4D = Unsqueeze(PaddingMask3D, One1D)") // [batch_size, 1, 1, KVSeqLen]
|
|
3818
|
-
.Add("AttnBiasCausalPad = Add(AttnBiasCausal, PaddingMask4D)");
|
|
3645
|
+
.Add("AttnBiasCausalPad = Add(AttnBiasCausalOrNot, PaddingMask4D)");
|
|
3819
3646
|
} else {
|
|
3820
|
-
builder.Add("AttnBiasCausalPad = Identity(AttnBiasCausal)");
|
|
3647
|
+
builder.Add("AttnBiasCausalPad = Identity(AttnBiasCausalOrNot)");
|
|
3821
3648
|
}
|
|
3822
3649
|
builder.Add("AttnBiasT = Cast (AttnBiasCausalPad)", "to", T1);
|
|
3823
3650
|
|
|
@@ -3832,10 +3659,25 @@ ONNX_OPERATOR_SET_SCHEMA(
|
|
|
3832
3659
|
.Add("RemainderNumHeads = Mod(QNumHeads, KVNumHeads)")
|
|
3833
3660
|
.Add("GQACond2 = Equal(RemainderNumHeads, Zero1D)")
|
|
3834
3661
|
.Add("GQACond = And(GQACond1, GQACond2)")
|
|
3835
|
-
.Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)")
|
|
3836
|
-
|
|
3837
|
-
|
|
3838
|
-
|
|
3662
|
+
.Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)");
|
|
3663
|
+
|
|
3664
|
+
// repeat kv (repeat_interleave)
|
|
3665
|
+
builder.Const1D("Two1D", static_cast<int64_t>(2))
|
|
3666
|
+
.Add("KUnsqueezed = Unsqueeze(PresentKey, Two1D)") // [B, Hk, 1, T, Dk]
|
|
3667
|
+
.Add("VUnsqueezed = Unsqueeze(PresentValue, Two1D)"); // [B, Hk, 1, T, Dv]
|
|
3668
|
+
|
|
3669
|
+
// Build expand shape: [B, Hk, repeats, T, Dk]
|
|
3670
|
+
builder
|
|
3671
|
+
.Add("KExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, QKHeadSize)")
|
|
3672
|
+
.Add("KExpanded = Expand(KUnsqueezed, KExpandShape)");
|
|
3673
|
+
builder.Add("VExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, VHeadSize)")
|
|
3674
|
+
.Add("VExpanded = Expand(VUnsqueezed, VExpandShape)");
|
|
3675
|
+
|
|
3676
|
+
// Reshape to [B, Hq, T, Dk] where Hq = Hk * repeats
|
|
3677
|
+
builder.Add("KAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, QKHeadSize)")
|
|
3678
|
+
.Add("VAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, VHeadSize)")
|
|
3679
|
+
.Add("KAttentionInput = Reshape(KExpanded, KAttentionShape)")
|
|
3680
|
+
.Add("VAttentionInput = Reshape(VExpanded, VAttentionShape)");
|
|
3839
3681
|
|
|
3840
3682
|
// The following pattern is applied
|
|
3841
3683
|
// Q K V
|