onnx 1.19.0__cp310-cp310-win_amd64.whl → 1.19.1rc1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of onnx is flagged as possibly problematic.
- onnx/__init__.py +98 -0
- onnx/backend/test/case/node/__init__.py +20 -3
- onnx/backend/test/case/node/attention.py +62 -0
- onnx/backend/test/case/node/rotaryembedding.py +6 -6
- onnx/backend/test/data/node/test_attention_3d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_3d_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_transpose_verification/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_transpose_verification_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_3d_with_past_and_present_qk_matmul_softmax_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_3d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_bool_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_mask4d_padded_kv_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_sizes_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask3d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_diff_heads_with_past_and_present_mask4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_fp16/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_fp16_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_attn_mask_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_scaled_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_softcap_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_gqa_with_past_and_present_fp16_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_scaled/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_scaled_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_3d_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal_expanded/test_data_set_0/output_3.pb +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_bias_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_past_and_present_qk_matmul_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_bias_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softcap_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax/model.onnx +0 -0
- onnx/backend/test/data/node/test_attention_4d_with_qk_matmul_softmax_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_1.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/input_2.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_1.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/input_2.pb +1 -1
- onnx/backend/test/data/node/test_rotary_embedding_no_position_ids_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_interleaved_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim/test_data_set_0/output_0.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/model.onnx +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_1.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/input_2.pb +0 -0
- onnx/backend/test/data/node/test_rotary_embedding_with_rotary_dim_expanded/test_data_set_0/output_0.pb +0 -0
- onnx/defs/nn/defs.cc +70 -228
- onnx/defs/nn/old.cc +31 -201
- onnx/defs/nn/utils.cc +222 -0
- onnx/defs/nn/utils.h +25 -0
- onnx/onnx_cpp2py_export.cp310-win_amd64.pyd +0 -0
- onnx/reference/ops/op_attention.py +33 -14
- onnx/reference/ops/op_rotary_embedding.py +21 -19
- onnx/test/basic_test.py +84 -0
- onnx/test/reference_evaluator_test.py +23 -0
- onnx/test/test_backend_reference.py +2 -1
- onnx/version.py +2 -2
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/METADATA +2 -2
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/RECORD +202 -200
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/WHEEL +1 -1
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/entry_points.txt +0 -0
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/licenses/LICENSE +0 -0
- {onnx-1.19.0.dist-info → onnx-1.19.1rc1.dist-info}/top_level.txt +0 -0
onnx/defs/nn/old.cc
CHANGED

@@ -5,6 +5,7 @@
 #include <cmath>
 
 #include "onnx/defs/function.h"
+#include "onnx/defs/nn/utils.h"
 #include "onnx/defs/schema.h"
 
 namespace ONNX_NAMESPACE {
@@ -4373,154 +4374,7 @@ ONNX_OPERATOR_SET_SCHEMA(
             "U",
             OpSchema::all_non_complex_numeric_types_plus_bool_ir4(),
             "Constrain output 'mask' types to boolean tensors and input types.")
-        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
-          propagateElemTypeFromInputToOutput(ctx, 0, 0);
-
-          int64_t kv_sequence_length = -1;
-          ONNX_NAMESPACE::TensorShapeProto output_shape;
-          ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
-          if (hasInputShape(ctx, 0)) {
-            auto& query_shape = getInputShape(ctx, 0);
-            auto& query_dims = query_shape.dim();
-            if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
-              fail_shape_inference("Inputs 0 (query) shall be 3 or 4 dimensions");
-            }
-
-            if (query_dims.size() == 3) {
-              auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
-              if (q_num_heads_attr == nullptr) {
-                fail_type_inference("3D inputs expected to have q_num_heads attribute.");
-              }
-              auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
-              if (kv_num_heads_attr == nullptr) {
-                fail_type_inference("3D inputs expected to have q_num_heads attribute.");
-              }
-            }
-
-            *output_shape.add_dim() = query_dims[0]; // batch_size
-            *output_shape.add_dim() = query_dims[1]; // num_heads for 4D, sequence_length for 3D
-
-            *qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
-
-            if (hasInputShape(ctx, 1)) {
-              auto& key_shape = getInputShape(ctx, 1);
-              auto& key_dims = key_shape.dim();
-              if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
-                fail_shape_inference("Inputs 1 (key) shall be 3 or 4 dimensions");
-              }
-            }
-
-            if (hasInputShape(ctx, 2)) {
-              auto& value_shape = getInputShape(ctx, 2);
-              auto& value_dims = value_shape.dim();
-              if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
-                fail_shape_inference("Inputs 2 (value) shall be 3 or 4 dimensions");
-              }
-
-              // Update Output Shape for 4D inputs
-              // Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
-              // Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
-              // Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
-              // Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
-              if (value_dims.size() == 4 && query_dims.size() == 4) {
-                kv_sequence_length = value_dims[2].dim_value();
-                *output_shape.add_dim() = query_dims[2]; // sequence_length
-                *output_shape.add_dim() = value_dims[3]; // head_size
-                updateOutputShape(ctx, 0, output_shape);
-                // Update qk_matmul_shape
-                *qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
-                *qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
-                qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
-              }
-
-              // Update Output Shape for 3D inputs
-              // Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
-              // q_hidden_size = q_num_heads * head_size
-              // Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
-              // k_hidden_size = kv_num_heads * head_size
-              // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
-              // v_hidden_size = kv_num_heads * v_head_size
-              // Output 0 has shape (batch_size, q_sequence_length, hidden_size),
-              // hidden_size = q_num_heads * v_head_size
-              if (value_dims.size() == 3 && query_dims.size() == 3) {
-                kv_sequence_length = value_dims[1].dim_value();
-                auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
-                if (q_num_heads_attr == nullptr) {
-                  fail_type_inference("3D inputs expected to have q_num_heads attribute.");
-                }
-                auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
-                if (kv_num_heads_attr == nullptr) {
-                  fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
-                }
-                int64_t q_num_heads = q_num_heads_attr->i();
-                int64_t kv_num_heads = kv_num_heads_attr->i();
-                // Calculate v_head_size
-                int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
-                output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
-                updateOutputShape(ctx, 0, output_shape);
-                // Update qk_matmul_shape
-                qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
-                *qk_matmul_shape.add_dim() = query_dims[1];
-                qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
-              }
-            }
-          }
-
-          if (ctx.hasOutput(3)) { // has qk_matmul_output
-            propagateElemTypeFromInputToOutput(ctx, 0, 3);
-            updateOutputShape(ctx, 3, qk_matmul_shape);
-          }
-
-          if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
-            if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key
-              // copy the type from query to present key and value
-              propagateElemTypeFromInputToOutput(ctx, 4, 1);
-              propagateElemTypeFromInputToOutput(ctx, 5, 2);
-
-              if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
-                auto& past_key_shape = getInputShape(ctx, 4);
-                auto& past_key_dims = past_key_shape.dim();
-                auto& past_value_shape = getInputShape(ctx, 5);
-                auto& past_value_dims = past_value_shape.dim();
-
-                // past key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
-                if (past_key_dims.size() != 4) {
-                  fail_shape_inference("The past_key input shall be 4 dimensions");
-                }
-                // past value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
-                if (past_value_dims.size() != 4) {
-                  fail_shape_inference("The past_value input shall be 4 dimensions");
-                }
-
-                if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
-                  int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
-
-                  ONNX_NAMESPACE::TensorShapeProto present_key_shape;
-                  for (auto& dim : past_key_dims) {
-                    *present_key_shape.add_dim() = dim;
-                  }
-
-                  ONNX_NAMESPACE::TensorShapeProto present_value_shape;
-                  for (auto& dim : past_value_dims) {
-                    *present_value_shape.add_dim() = dim;
-                  }
-
-                  if (ctx.hasOutput(3)) { // has qk_matmul_output with bias
-                    qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
-                    updateOutputShape(ctx, 3, qk_matmul_shape);
-                  }
-
-                  // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
-                  present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
-                  present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
-
-                  updateOutputShape(ctx, 1, present_key_shape);
-                  updateOutputShape(ctx, 2, present_value_shape);
-                }
-              }
-            }
-          }
-        })
+        .TypeAndShapeInferenceFunction(defs::nn::utils::AttentionPropagateElemTypeFromInputToOutput)
         .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx,
                                                    const OpSchema& schema,
                                                    FunctionProto& functionProto) {
@@ -4544,12 +4398,6 @@ ONNX_OPERATOR_SET_SCHEMA(
               (softmax_precision != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE))
             return false; // Error
 
-          auto mkbooltensor = [](bool val) -> ONNX_NAMESPACE::TensorProto {
-            auto tp = ONNX_NAMESPACE::ToTensor(std::vector<bool>{val});
-            tp.add_dims(1);
-            return tp;
-          };
-
           // If shape is 3D, q_num_heads and kv_num_heads is provided,
           // for 4D cases, set num_heads to zero for reshape purposes
           auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
@@ -4561,15 +4409,16 @@ ONNX_OPERATOR_SET_SCHEMA(
           bool is_3d_input = (q_num_heads > 0 && kv_num_heads > 0);
 
           FunctionBuilder builder(functionProto);
+          builder
+              .Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
+              .Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
+              .Add("KVSeqLen = Shape <start = -2, end = -1> (K)"); // kv_sequence_length
           if (is_3d_input) {
             // For 3D inputs: First reshape to [batch_size, seq_length, num_heads, head_size]
             // then transpose to [batch_size, num_heads, seq_length, head_size]
             builder
-                .Add("BatchSize = Shape <start = 0, end = 1> (Q)") // batch size
                 .Const1D("QNumHeadsAttr", q_num_heads) // q_num_heads from attrs
                 .Const1D("KVNumHeadsAttr", kv_num_heads) // kv_num_heads from attrs
-                .Add("QSeqLen = Shape <start = -2, end = -1> (Q)") // q_sequence_length
-                .Add("KVSeqLen = Shape <start = -2, end = -1> (K)") // kv_sequence_length
                 .Const1D("NegOne", static_cast<int64_t>(-1)); // head_size, inferred from other dimensions
 
             builder.Add("QIntermediateShape = Concat <axis = 0> (BatchSize, QSeqLen, QNumHeadsAttr, NegOne)")
@@ -4596,6 +4445,7 @@ ONNX_OPERATOR_SET_SCHEMA(
           builder
               .Add("QKHeadSize = Shape <start = 3, end = 4> (QReshaped)") // head_size for Q and K
               .Add("QKHeadSizeF = Cast (QKHeadSize)", "to", float_type)
+              .Add("VHeadSize = Shape <start = 3, end = 4> (VReshaped)") // head_size for V
              .Add("SqrtHeadSize = Sqrt(QKHeadSizeF)")
              .Const1D("One1D", static_cast<int64_t>(1))
              .Const1D("One1DF", static_cast<float>(1))
@@ -4610,8 +4460,10 @@ ONNX_OPERATOR_SET_SCHEMA(
 
           if (ctx.hasInput(4)) {
             builder.Add("PresentKey = Concat <axis = 2> (past_key, KReshaped)");
+            builder.Add("PastKVSeqLen = Shape <start = -2, end = -1> (past_key)");
           } else {
             builder.Add("PresentKey = Identity (KReshaped)");
+            builder.Const1D("PastKVSeqLen", static_cast<int64_t>(0));
           }
           if (ctx.hasOutput(1)) {
             builder.Add("present_key = Identity (PresentKey)");
@@ -4626,46 +4478,9 @@ ONNX_OPERATOR_SET_SCHEMA(
             builder.Add("present_value = Identity (PresentValue)");
           }
 
-          float neg_inf = -std::numeric_limits<float>::infinity();
-          // If attn_mask is provided
-          if (ctx.hasInput(3)) {
-            auto* up = ctx.getInputType(3);
-            if ((up == nullptr) || (!up->has_tensor_type()))
-              return false;
-            int64_t U = up->tensor_type().elem_type();
-            builder.Const1D("FloatInf", neg_inf);
-            builder.Const1D("ScalarZero", 0.f);
-            builder.Add(
-                U == ONNX_NAMESPACE::TensorProto_DataType_BOOL ? "AttnBias = Where(attn_mask, ScalarZero, FloatInf)"
-                                                               : "AttnBias = Identity(attn_mask)");
-          } else {
-            // If is_causal set to true, the attention masking is a lower triangular matrix when the mask
-            // is a square matrix. The attention masking has the form of the upper left causal bias due to
-            // the alignment when the mask is a non-square matrix.
-            // An error is thrown if both attn_mask and is_causal are set.
-            auto* is_causal_attr = ctx.getAttribute("is_causal");
-            int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
-            if (!is_3d_input) {
-              builder.Add("QSeqLen = Shape <start = -2, end = -1> (Q)"); // q_sequence_length
-            }
-
-            if (is_causal == 1) {
-              builder.Const1D("FloatInf", neg_inf);
-              builder.Const1D("ScalarZero", 0.f);
-              builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
-                  .Add("AttnBiasShape = Concat <axis = -1> (QSeqLen, NewKVSeqLen)")
-                  .Add("AttnBiasZeros_ = ConstantOfShape(AttnBiasShape)")
-                  .Add("AttnBiasZeros = CastLike(AttnBiasZeros_, Q)")
-                  .Add("BoolMask = ConstantOfShape(AttnBiasShape)", "value", mkbooltensor(1))
-                  .Add("BoolMaskTri = Trilu <upper = 0> (BoolMask, Zero1D)")
-                  .Add("AttnBias = Where(BoolMaskTri, ScalarZero, FloatInf)");
-            } else {
-              builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
-                  .Add("AttnBiasShape = Concat <axis = -1> (QSeqLen, NewKVSeqLen)")
-                  .Add("AttnBias = ConstantOfShape(AttnBiasShape)");
-            }
-          }
-          builder.Add("AttnBiasT = Cast (AttnBias)", "to", T1);
+          if (!defs::nn::utils::AttentionAppendFunctionCausalMask(ctx, builder, false))
+            return false;
+          builder.Add("AttnBiasT = Cast (AttnBiasCausalOrNot)", "to", T1);
 
           // Group Query Attention is applied if the following are satisfied
           // 1) q_num_heads != kv_num_heads
@@ -4678,10 +4493,25 @@ ONNX_OPERATOR_SET_SCHEMA(
               .Add("RemainderNumHeads = Mod(QNumHeads, KVNumHeads)")
              .Add("GQACond2 = Equal(RemainderNumHeads, Zero1D)")
              .Add("GQACond = And(GQACond1, GQACond2)")
-              .Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)")
-
-
-
+              .Add("InterleaveDim = Where(GQACond, IDivNumHeads, One1D)");
+
+          // repeat kv (repeat_interleave)
+          builder.Const1D("Two1D", static_cast<int64_t>(2))
+              .Add("KUnsqueezed = Unsqueeze(PresentKey, Two1D)") // [B, Hk, 1, T, Dk]
+              .Add("VUnsqueezed = Unsqueeze(PresentValue, Two1D)"); // [B, Hk, 1, T, Dv]
+
+          // Build expand shape: [B, Hk, repeats, T, Dk]
+          builder
+              .Add("KExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, QKHeadSize)")
+              .Add("KExpanded = Expand(KUnsqueezed, KExpandShape)");
+          builder.Add("VExpandShape = Concat <axis = 0> (BatchSize, KVNumHeads, InterleaveDim, NewKVSeqLen, VHeadSize)")
+              .Add("VExpanded = Expand(VUnsqueezed, VExpandShape)");
+
+          // Reshape to [B, Hq, T, Dk] where Hq = Hk * repeats
+          builder.Add("KAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, QKHeadSize)")
+              .Add("VAttentionShape = Concat <axis = 0> (BatchSize, QNumHeads, NewKVSeqLen, VHeadSize)")
+              .Add("KAttentionInput = Reshape(KExpanded, KAttentionShape)")
+              .Add("VAttentionInput = Reshape(VExpanded, VAttentionShape)");
 
           // The following pattern is applied
           // Q K V
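The new Unsqueeze/Expand/Reshape sequence above makes the GQA head expansion an explicit repeat_interleave inside the generated function body. A minimal numpy sketch of the same index arithmetic (repeat_kv is an illustrative name, not part of the ONNX function):

import numpy as np

# Mirrors KUnsqueezed -> KExpanded -> KAttentionInput, assuming a 4D key
# [B, Hk, T, Dk] expanded to Hq = Hk * repeats query heads.
def repeat_kv(K, repeats):
    B, Hk, T, Dk = K.shape
    K5 = K[:, :, None, :, :]                           # Unsqueeze: [B, Hk, 1, T, Dk]
    K5 = np.broadcast_to(K5, (B, Hk, repeats, T, Dk))  # Expand:    [B, Hk, repeats, T, Dk]
    return K5.reshape(B, Hk * repeats, T, Dk)          # Reshape:   [B, Hq, T, Dk]

K = np.arange(8, dtype=np.float32).reshape(1, 2, 2, 2)
assert np.array_equal(repeat_kv(K, 2), np.repeat(K, 2, axis=1))  # interleaved head order

Flattening the [Hk, repeats] axis pair is what yields the interleaved [h0, h0, h1, h1, ...] head order noted in the comments.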
onnx/defs/nn/utils.cc
ADDED

@@ -0,0 +1,222 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+#include "onnx/defs/nn/utils.h"
+
+#include <algorithm>
+
+namespace ONNX_NAMESPACE {
+namespace defs {
+namespace nn {
+namespace utils {
+
+void AttentionPropagateElemTypeFromInputToOutput(InferenceContext& ctx) {
+  propagateElemTypeFromInputToOutput(ctx, 0, 0);
+
+  int64_t kv_sequence_length = -1;
+  ONNX_NAMESPACE::TensorShapeProto output_shape;
+  ONNX_NAMESPACE::TensorShapeProto qk_matmul_shape;
+  if (hasInputShape(ctx, 0)) {
+    auto& query_shape = getInputShape(ctx, 0);
+    auto& query_dims = query_shape.dim();
+    if ((query_dims.size() != 3) && (query_dims.size() != 4)) {
+      fail_shape_inference("Inputs 0 (query) shall be 3 or 4 dimensions");
+    }
+
+    if (query_dims.size() == 3) {
+      auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
+      if (q_num_heads_attr == nullptr) {
+        fail_type_inference("3D inputs expected to have q_num_heads attribute.");
+      }
+      auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
+      if (kv_num_heads_attr == nullptr) {
+        fail_type_inference("3D inputs expected to have q_num_heads attribute.");
+      }
+    }
+
+    *output_shape.add_dim() = query_dims[0]; // batch_size
+    *output_shape.add_dim() = query_dims[1]; // num_heads for 4D, sequence_length for 3D
+
+    *qk_matmul_shape.add_dim() = query_dims[0]; // batch_size
+
+    if (hasInputShape(ctx, 1)) {
+      auto& key_shape = getInputShape(ctx, 1);
+      auto& key_dims = key_shape.dim();
+      if ((key_dims.size() != 3) && (key_dims.size() != 4)) {
+        fail_shape_inference("Inputs 1 (key) shall be 3 or 4 dimensions");
+      }
+    }
+
+    if (hasInputShape(ctx, 2)) {
+      auto& value_shape = getInputShape(ctx, 2);
+      auto& value_dims = value_shape.dim();
+      if ((value_dims.size() != 3) && (value_dims.size() != 4)) {
+        fail_shape_inference("Inputs 2 (value) shall be 3 or 4 dimensions");
+      }
+
+      // Update Output Shape for 4D inputs
+      // Input 0 (query) has shape (batch_size, q_num_heads, q_sequence_length, head_size)
+      // Input 1 (key) has shape (batch_size, kv_num_heads, kv_sequence_length, head_size)
+      // Input 2 (value) has shape (batch_size, kv_num_heads, kv_sequence_length, v_head_size)
+      // Output 0 has shape (batch_size, q_num_heads, q_sequence_length, v_head_size)
+      if (value_dims.size() == 4 && query_dims.size() == 4) {
+        kv_sequence_length = value_dims[2].dim_value();
+        *output_shape.add_dim() = query_dims[2]; // sequence_length
+        *output_shape.add_dim() = value_dims[3]; // head_size
+        updateOutputShape(ctx, 0, output_shape);
+        // Update qk_matmul_shape
+        *qk_matmul_shape.add_dim() = query_dims[1]; // q_num_heads
+        *qk_matmul_shape.add_dim() = query_dims[2]; // q_sequence_length
+        qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
+      }
+
+      // Update Output Shape for 3D inputs
+      // Input 0 (query) has shape (batch_size, q_sequence_length, q_hidden_size),
+      // q_hidden_size = q_num_heads * head_size
+      // Input 1 (key) has shape (batch_size, kv_sequence_length, k_hidden_size),
+      // k_hidden_size = kv_num_heads * head_size
+      // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size),
+      // v_hidden_size = kv_num_heads * v_head_size
+      // Output 0 has shape (batch_size, q_sequence_length, hidden_size),
+      // hidden_size = q_num_heads * v_head_size
+      if (value_dims.size() == 3 && query_dims.size() == 3) {
+        kv_sequence_length = value_dims[1].dim_value();
+        auto* q_num_heads_attr = ctx.getAttribute("q_num_heads");
+        if (q_num_heads_attr == nullptr) {
+          fail_type_inference("3D inputs expected to have q_num_heads attribute.");
+        }
+        auto* kv_num_heads_attr = ctx.getAttribute("kv_num_heads");
+        if (kv_num_heads_attr == nullptr) {
+          fail_type_inference("3D inputs expected to have kv_num_heads attribute.");
+        }
+        int64_t q_num_heads = q_num_heads_attr->i();
+        int64_t kv_num_heads = kv_num_heads_attr->i();
+        // Calculate v_head_size
+        int64_t v_head_size = value_dims[2].dim_value() / kv_num_heads;
+        output_shape.add_dim()->set_dim_value(v_head_size * q_num_heads);
+        updateOutputShape(ctx, 0, output_shape);
+        // Update qk_matmul_shape
+        qk_matmul_shape.add_dim()->set_dim_value(q_num_heads);
+        *qk_matmul_shape.add_dim() = query_dims[1];
+        qk_matmul_shape.add_dim()->set_dim_value(kv_sequence_length);
+      }
+    }
+  }
+
+  if (ctx.hasOutput(3)) { // has qk_matmul_output
+    propagateElemTypeFromInputToOutput(ctx, 0, 3);
+    updateOutputShape(ctx, 3, qk_matmul_shape);
+  }
+
+  if (ctx.hasOutput(1) && ctx.hasOutput(2)) { // has present outputs
+    if (ctx.hasInput(4) && ctx.hasInput(5)) { // has past_key
+      // copy the type from query to present key and value
+      propagateElemTypeFromInputToOutput(ctx, 4, 1);
+      propagateElemTypeFromInputToOutput(ctx, 5, 2);
+
+      if (hasInputShape(ctx, 4) && hasInputShape(ctx, 5)) {
+        auto& past_key_shape = getInputShape(ctx, 4);
+        auto& past_key_dims = past_key_shape.dim();
+        auto& past_value_shape = getInputShape(ctx, 5);
+        auto& past_value_dims = past_value_shape.dim();
+
+        // past key has shape (batch_size, kv_num_heads, past_sequence_length, head_size)
+        if (past_key_dims.size() != 4) {
+          fail_shape_inference("The past_key input shall be 4 dimensions");
+        }
+        // past value has shape (batch_size, kv_num_heads, past_sequence_length, v_head_size)
+        if (past_value_dims.size() != 4) {
+          fail_shape_inference("The past_value input shall be 4 dimensions");
+        }
+
+        if (kv_sequence_length > 0 && past_key_dims[2].has_dim_value()) {
+          int64_t total_sequence_length = kv_sequence_length + past_key_dims[2].dim_value();
+
+          ONNX_NAMESPACE::TensorShapeProto present_key_shape;
+          for (auto& dim : past_key_dims) {
+            *present_key_shape.add_dim() = dim;
+          }
+
+          ONNX_NAMESPACE::TensorShapeProto present_value_shape;
+          for (auto& dim : past_value_dims) {
+            *present_value_shape.add_dim() = dim;
+          }
+
+          if (ctx.hasOutput(3)) { // has qk_matmul_output with bias
+            qk_matmul_shape.mutable_dim(3)->set_dim_value(total_sequence_length);
+            updateOutputShape(ctx, 3, qk_matmul_shape);
+          }
+
+          // shape of present key/value is (batch_size, kv_num_heads, total_sequence_length, head_size)
+          present_key_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
+          present_value_shape.mutable_dim(2)->set_dim_value(total_sequence_length);
+
+          updateOutputShape(ctx, 1, present_key_shape);
+          updateOutputShape(ctx, 2, present_value_shape);
+        }
+      }
+    }
+  }
+}
+
+bool AttentionAppendFunctionCausalMask(const FunctionBodyBuildContext& ctx, FunctionBuilder& builder, bool padding) {
+  builder.Add("NewKVSeqLen = Shape <start = -2, end = -1> (PresentKey)")
+      .Add("AttnBiasShape = Concat <axis = 0> (QSeqLen, NewKVSeqLen)");
+  float neg_inf = -std::numeric_limits<float>::infinity();
+  builder.Const1D("FloatNegInf", neg_inf);
+  builder.Const1D("ScalarZero", 0.f);
+
+  // If attn_mask is provided
+  if (ctx.hasInput(3)) {
+    auto* up = ctx.getInputType(3);
+    if ((up == nullptr) || (!up->has_tensor_type()))
+      return false;
+    int64_t U = up->tensor_type().elem_type();
+    builder.Add(
+        U == ONNX_NAMESPACE::TensorProto_DataType_BOOL ? "AttnBiasShort = Where(attn_mask, ScalarZero, FloatNegInf)"
+                                                       : "AttnBiasShort = Identity(attn_mask)");
+    // If attn_mask has a shorter kv sequence length, we pad it to NewKVSeqLen with FloatNegInf
+    if (padding) {
+      builder.Add("MaskKVSeqLen = Shape <start = -1> (attn_mask)")
+          .Add("PaddingKVSeqLen = Sub(NewKVSeqLen, MaskKVSeqLen)")
+          .Add("Pads = Concat <axis = 0> (Zero1D, PaddingKVSeqLen)")
+          .Add("FloatNegInfCast = CastLike(FloatNegInf, AttnBiasShort)")
+          .Add("AttnBias = Pad(AttnBiasShort, Pads, FloatNegInfCast, NegOne1D)");
+    } else {
+      builder.Add("AttnBias = Identity(AttnBiasShort)");
+    }
+  } else {
+    builder.Add("AttnBias = ConstantOfShape(AttnBiasShape)");
+  }
+
+  // If is_causal set to true, the attention masking is a lower triangular matrix when the mask
+  // is a square matrix. The attention masking has the form of the upper left causal bias due to
+  // the alignment when the mask is a non-square matrix.
+  // An error is thrown if both attn_mask and is_causal are set.
+  auto* is_causal_attr = ctx.getAttribute("is_causal");
+  int64_t is_causal = (is_causal_attr != nullptr) ? is_causal_attr->i() : 0;
+  if (is_causal == 1) {
+    builder.Const1D("Zero", static_cast<int64_t>(0))
+        .Const1D("One", static_cast<int64_t>(1))
+        .Add("ZeroNoDim = Squeeze(Zero, Zero)")
+        .Add("OneNoDim = Squeeze(One, Zero)")
+        .Add("SequenceLength = Gather(AttnBiasShape, ZeroNoDim)")
+        .Add("TotalSequenceLength = Gather(AttnBiasShape, OneNoDim)")
+        .Add("RangeRow = Range(ZeroNoDim, SequenceLength, OneNoDim)")
+        .Add("RangeRow2D = Unsqueeze(RangeRow, One)")
+        .Add("RangeCol = Range(ZeroNoDim, TotalSequenceLength, OneNoDim)")
+        .Add("RangeCol2D = Unsqueeze(RangeCol, Zero)")
+        .Add("RangeRow2DPast = Add(RangeRow2D, PastKVSeqLen)")
+        .Add("BoolMaskTri = Less(RangeRow2DPast, RangeCol2D)")
+        .Add("MaskTri = Where(BoolMaskTri, FloatNegInf, ScalarZero)")
+        .Add("AttnBiasCausalOrNot = Add(AttnBias, MaskTri)");
+  } else {
+    builder.Add("AttnBiasCausalOrNot = Identity(AttnBias)");
+  }
+  return true;
+}
+
+} // namespace utils
+} // namespace nn
+} // namespace defs
+} // namespace ONNX_NAMESPACE
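AttentionAppendFunctionCausalMask builds the causal bias from Range/Unsqueeze/Less/Where nodes and, notably, offsets the row index by PastKVSeqLen so that cached positions remain visible to new query tokens. A small numpy sketch of that arithmetic (causal_bias is our illustrative helper, not an ONNX symbol):

import numpy as np

def causal_bias(q_len, total_len, past_len):
    row = np.arange(q_len)[:, None] + past_len  # RangeRow2D + PastKVSeqLen
    col = np.arange(total_len)[None, :]         # RangeCol2D
    masked = row < col                          # BoolMaskTri = Less(...)
    return np.where(masked, -np.inf, 0.0)       # MaskTri

# One new query token with two cached positions may attend to columns 0..2:
print(causal_bias(q_len=1, total_len=4, past_len=2))  # [[ 0.  0.  0. -inf]]

Without the past_len offset, the single query row would align with column 0 and wrongly mask most of the cache, which appears to be why the causal-with-past test outputs in the listing were regenerated.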
onnx/defs/nn/utils.h
ADDED

@@ -0,0 +1,25 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+#include "onnx/common/assertions.h"
+#include "onnx/defs/function.h"
+#include "onnx/defs/schema.h"
+
+namespace ONNX_NAMESPACE {
+namespace defs {
+namespace nn {
+namespace utils {
+
+/** Implements shape and type propagation for Attention (23-). */
+void AttentionPropagateElemTypeFromInputToOutput(InferenceContext& ctx);
+
+/** Implements CausalMask for Attention. */
+bool AttentionAppendFunctionCausalMask(const FunctionBodyBuildContext& ctx, FunctionBuilder& builder, bool padding);
+
+} // namespace utils
+} // namespace nn
+} // namespace defs
+} // namespace ONNX_NAMESPACE
onnx/onnx_cpp2py_export.cp310-win_amd64.pyd
CHANGED
Binary file

onnx/reference/ops/op_attention.py
CHANGED

@@ -24,6 +24,25 @@ def _softcap(X, softcap):
     return X
 
 
+def _apply_causal(mask, past_sequence_length):
+    """Applies a causal mask on the input `mask`:
+    ``mask[i, j] = -inf if past_sequence_length + i > j else 0``.
+    Because a softmax is applied on the mask, -inf becomes 0 and 0 becomes 1.
+    The modification is done inplace.
+    """
+    q_sequence_length, total_sequence_length = mask.shape[-2:]
+    triu = np.triu(
+        np.ones(
+            (q_sequence_length, total_sequence_length - past_sequence_length),
+            dtype=mask.dtype,
+        ),
+        k=1,
+    )
+    triu[triu == 1] = -np.inf
+    mask[..., :, past_sequence_length:] += triu
+    return mask
+
+
 def _compute_attention(
     Q: np.ndarray,
     K: np.ndarray,
@@ -114,18 +133,18 @@
     # bias due to the alignment when the mask is a non-square matrix.
     if is_causal:
         if attn_mask is None:
-            temp_mask = np.
-
-
-
-
+            temp_mask = np.zeros((q_sequence_length, kv_sequence_length), dtype=Q.dtype)
+            attn_bias = _apply_causal(
+                temp_mask,
+                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
+            )
         else:
             if attn_mask.dtype == np.bool_:
                 attn_mask = (1 - attn_mask).astype(Q.dtype) * (-np.inf)
-
-
-
-
+            attn_bias = _apply_causal(
+                attn_mask.copy(),
+                past_sequence_length=past_key.shape[2] if past_key is not None else 0,
+            )
     elif attn_mask is not None:
         if attn_mask.dtype == np.bool_:
             attn_mask = (1 - attn_mask).astype(Q.dtype)
@@ -158,10 +177,10 @@
         and (q_num_heads % k_num_heads == 0)
         and (k_num_heads == v_num_heads)
     ):
-        seq_reps =
-
-        K = np.
-        V = np.
+        seq_reps = q_num_heads // k_num_heads
+        # Interleave-repeat each KV head: [h0, h0, h1, h1, ...]
+        K = np.repeat(K, repeats=seq_reps, axis=1)
+        V = np.repeat(V, repeats=seq_reps, axis=1)
 
     # The following pattern is applied
     # Q K V
@@ -185,7 +204,7 @@
     qk_matmul_output = np.matmul(Q * scale, k_transpose * scale)
     qk_with_bias = qk_matmul_output + attn_bias
     if qk_matmul_output_mode == 1:
-        qk_matmul_output =
+        qk_matmul_output = qk_with_bias.copy()
 
     # Apply softcap
     if softcap is not None: