sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
 - sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
| 
         @@ -0,0 +1,1359 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Adapted from https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/layers/fla/ops/kda.py
         
     | 
| 
      
 2 
     | 
    
         
            +
            # This file contains code copied from the flash-linear-attention project.
         
     | 
| 
      
 3 
     | 
    
         
            +
            # The original source code was licensed under the MIT license and included
         
     | 
| 
      
 4 
     | 
    
         
            +
            # the following copyright notice:
         
     | 
| 
      
 5 
     | 
    
         
            +
            # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 8 
     | 
    
         
            +
            import torch.nn as nn
         
     | 
| 
      
 9 
     | 
    
         
            +
            import triton
         
     | 
| 
      
 10 
     | 
    
         
            +
            import triton.language as tl
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
         
     | 
| 
      
 13 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
         
     | 
| 
      
 14 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.fused_recurrent import (
         
     | 
| 
      
 15 
     | 
    
         
            +
                fused_recurrent_gated_delta_rule_fwd_kernel,
         
     | 
| 
      
 16 
     | 
    
         
            +
            )
         
     | 
| 
      
 17 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
         
     | 
| 
      
 18 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
         
     | 
| 
      
 19 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.op import exp, log
         
     | 
| 
      
 20 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.solve_tril import solve_tril
         
     | 
| 
      
 21 
     | 
    
         
            +
            from sglang.srt.layers.attention.fla.utils import is_amd
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
            BT_LIST_AUTOTUNE = [32, 64, 128]
         
     | 
| 
      
 24 
     | 
    
         
            +
            NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
         
     | 
| 
      
 25 
     | 
    
         
            +
             
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
            def cdiv(a: int, b: int) -> int:
         
     | 
| 
      
 28 
     | 
    
         
            +
                """Ceiling division."""
         
     | 
| 
      
 29 
     | 
    
         
            +
                return -(a // -b)
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
            def next_power_of_2(n: int) -> int:
         
     | 
| 
      
 33 
     | 
    
         
            +
                """The next power of 2 (inclusive)"""
         
     | 
| 
      
 34 
     | 
    
         
            +
                if n < 1:
         
     | 
| 
      
 35 
     | 
    
         
            +
                    return 1
         
     | 
| 
      
 36 
     | 
    
         
            +
                return 1 << (n - 1).bit_length()
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
      
 38 
     | 
    
         
            +
             
     | 
| 
      
 39 
     | 
    
         
            +
            def fused_recurrent_kda_fwd(
         
     | 
| 
      
 40 
     | 
    
         
            +
                q: torch.Tensor,
         
     | 
| 
      
 41 
     | 
    
         
            +
                k: torch.Tensor,
         
     | 
| 
      
 42 
     | 
    
         
            +
                v: torch.Tensor,
         
     | 
| 
      
 43 
     | 
    
         
            +
                g: torch.Tensor,
         
     | 
| 
      
 44 
     | 
    
         
            +
                beta: torch.Tensor,
         
     | 
| 
      
 45 
     | 
    
         
            +
                scale: float,
         
     | 
| 
      
 46 
     | 
    
         
            +
                initial_state: torch.Tensor,
         
     | 
| 
      
 47 
     | 
    
         
            +
                inplace_final_state: bool = True,
         
     | 
| 
      
 48 
     | 
    
         
            +
                cu_seqlens: torch.LongTensor | None = None,
         
     | 
| 
      
 49 
     | 
    
         
            +
                # ssm_state_indices: torch.Tensor | None = None,
         
     | 
| 
      
 50 
     | 
    
         
            +
                num_accepted_tokens: torch.Tensor | None = None,
         
     | 
| 
      
 51 
     | 
    
         
            +
                use_qk_l2norm_in_kernel: bool = False,
         
     | 
| 
      
 52 
     | 
    
         
            +
            ) -> tuple[torch.Tensor, torch.Tensor]:
         
     | 
| 
      
 53 
     | 
    
         
            +
                B, T, H, K, V = *k.shape, v.shape[-1]
         
     | 
| 
      
 54 
     | 
    
         
            +
                HV = v.shape[2]
         
     | 
| 
      
 55 
     | 
    
         
            +
                N = B if cu_seqlens is None else len(cu_seqlens) - 1
         
     | 
| 
      
 56 
     | 
    
         
            +
                BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
         
     | 
| 
      
 57 
     | 
    
         
            +
                NK, NV = cdiv(K, BK), cdiv(V, BV)
         
     | 
| 
      
 58 
     | 
    
         
            +
                assert NK == 1, "NK > 1 is not supported yet"
         
     | 
| 
      
 59 
     | 
    
         
            +
                num_stages = 3
         
     | 
| 
      
 60 
     | 
    
         
            +
                num_warps = 1
         
     | 
| 
      
 61 
     | 
    
         
            +
             
     | 
| 
      
 62 
     | 
    
         
            +
                o = torch.empty_like(k)
         
     | 
| 
      
 63 
     | 
    
         
            +
                if inplace_final_state:
         
     | 
| 
      
 64 
     | 
    
         
            +
                    final_state = initial_state
         
     | 
| 
      
 65 
     | 
    
         
            +
                else:
         
     | 
| 
      
 66 
     | 
    
         
            +
                    final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
         
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
      
 68 
     | 
    
         
            +
                stride_init_state_token = initial_state.stride(0)
         
     | 
| 
      
 69 
     | 
    
         
            +
                stride_final_state_token = final_state.stride(0)
         
     | 
| 
      
 70 
     | 
    
         
            +
             
     | 
| 
      
 71 
     | 
    
         
            +
                # if ssm_state_indices is None:
         
     | 
| 
      
 72 
     | 
    
         
            +
                #     stride_indices_seq, stride_indices_tok = 1, 1
         
     | 
| 
      
 73 
     | 
    
         
            +
                # elif ssm_state_indices.ndim == 1:
         
     | 
| 
      
 74 
     | 
    
         
            +
                #     stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
         
     | 
| 
      
 75 
     | 
    
         
            +
                # else:
         
     | 
| 
      
 76 
     | 
    
         
            +
                #     stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
                grid = (NK, NV, N * HV)
         
     | 
| 
      
 79 
     | 
    
         
            +
                fused_recurrent_gated_delta_rule_fwd_kernel[grid](
         
     | 
| 
      
 80 
     | 
    
         
            +
                    q=q,
         
     | 
| 
      
 81 
     | 
    
         
            +
                    k=k,
         
     | 
| 
      
 82 
     | 
    
         
            +
                    v=v,
         
     | 
| 
      
 83 
     | 
    
         
            +
                    g=g,
         
     | 
| 
      
 84 
     | 
    
         
            +
                    beta=beta,
         
     | 
| 
      
 85 
     | 
    
         
            +
                    o=o,
         
     | 
| 
      
 86 
     | 
    
         
            +
                    h0=initial_state,
         
     | 
| 
      
 87 
     | 
    
         
            +
                    ht=final_state,
         
     | 
| 
      
 88 
     | 
    
         
            +
                    cu_seqlens=cu_seqlens,
         
     | 
| 
      
 89 
     | 
    
         
            +
                    # ssm_state_indices=ssm_state_indices,
         
     | 
| 
      
 90 
     | 
    
         
            +
                    # num_accepted_tokens=num_accepted_tokens,
         
     | 
| 
      
 91 
     | 
    
         
            +
                    scale=scale,
         
     | 
| 
      
 92 
     | 
    
         
            +
                    # N=N,
         
     | 
| 
      
 93 
     | 
    
         
            +
                    T=T,
         
     | 
| 
      
 94 
     | 
    
         
            +
                    B=B,
         
     | 
| 
      
 95 
     | 
    
         
            +
                    H=H,
         
     | 
| 
      
 96 
     | 
    
         
            +
                    HV=HV,
         
     | 
| 
      
 97 
     | 
    
         
            +
                    K=K,
         
     | 
| 
      
 98 
     | 
    
         
            +
                    V=V,
         
     | 
| 
      
 99 
     | 
    
         
            +
                    BK=BK,
         
     | 
| 
      
 100 
     | 
    
         
            +
                    BV=BV,
         
     | 
| 
      
 101 
     | 
    
         
            +
                    # stride_init_state_token=stride_init_state_token,
         
     | 
| 
      
 102 
     | 
    
         
            +
                    # stride_final_state_token=stride_final_state_token,
         
     | 
| 
      
 103 
     | 
    
         
            +
                    # stride_indices_seq=stride_indices_seq,
         
     | 
| 
      
 104 
     | 
    
         
            +
                    # stride_indices_tok=stride_indices_tok,
         
     | 
| 
      
 105 
     | 
    
         
            +
                    IS_BETA_HEADWISE=beta.ndim == v.ndim,
         
     | 
| 
      
 106 
     | 
    
         
            +
                    USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
         
     | 
| 
      
 107 
     | 
    
         
            +
                    # INPLACE_FINAL_STATE=inplace_final_state,
         
     | 
| 
      
 108 
     | 
    
         
            +
                    IS_KDA=True,
         
     | 
| 
      
 109 
     | 
    
         
            +
                    num_warps=num_warps,
         
     | 
| 
      
 110 
     | 
    
         
            +
                    num_stages=num_stages,
         
     | 
| 
      
 111 
     | 
    
         
            +
                )
         
     | 
| 
      
 112 
     | 
    
         
            +
             
     | 
| 
      
 113 
     | 
    
         
            +
                return o, final_state
         
     | 
| 
      
 114 
     | 
    
         
            +
             
     | 
| 
      
 115 
     | 
    
         
            +
             
     | 
| 
      
 116 
     | 
    
         
            +
            def fused_recurrent_kda(
         
     | 
| 
      
 117 
     | 
    
         
            +
                q: torch.Tensor,
         
     | 
| 
      
 118 
     | 
    
         
            +
                k: torch.Tensor,
         
     | 
| 
      
 119 
     | 
    
         
            +
                v: torch.Tensor,
         
     | 
| 
      
 120 
     | 
    
         
            +
                g: torch.Tensor,
         
     | 
| 
      
 121 
     | 
    
         
            +
                beta: torch.Tensor = None,
         
     | 
| 
      
 122 
     | 
    
         
            +
                scale: float = None,
         
     | 
| 
      
 123 
     | 
    
         
            +
                initial_state: torch.Tensor = None,
         
     | 
| 
      
 124 
     | 
    
         
            +
                inplace_final_state: bool = True,
         
     | 
| 
      
 125 
     | 
    
         
            +
                use_qk_l2norm_in_kernel: bool = True,
         
     | 
| 
      
 126 
     | 
    
         
            +
                cu_seqlens: torch.LongTensor | None = None,
         
     | 
| 
      
 127 
     | 
    
         
            +
                # ssm_state_indices: torch.LongTensor | None = None,
         
     | 
| 
      
 128 
     | 
    
         
            +
                **kwargs,
         
     | 
| 
      
 129 
     | 
    
         
            +
            ) -> tuple[torch.Tensor, torch.Tensor]:
         
     | 
| 
      
 130 
     | 
    
         
            +
                if cu_seqlens is not None and q.shape[0] != 1:
         
     | 
| 
      
 131 
     | 
    
         
            +
                    raise ValueError(
         
     | 
| 
      
 132 
     | 
    
         
            +
                        f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
         
     | 
| 
      
 133 
     | 
    
         
            +
                        f"Please flatten variable-length inputs before processing."
         
     | 
| 
      
 134 
     | 
    
         
            +
                    )
         
     | 
| 
      
 135 
     | 
    
         
            +
                if scale is None:
         
     | 
| 
      
 136 
     | 
    
         
            +
                    scale = k.shape[-1] ** -0.5
         
     | 
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
                o, final_state = fused_recurrent_kda_fwd(
         
     | 
| 
      
 139 
     | 
    
         
            +
                    q=q.contiguous(),
         
     | 
| 
      
 140 
     | 
    
         
            +
                    k=k.contiguous(),
         
     | 
| 
      
 141 
     | 
    
         
            +
                    v=v.contiguous(),
         
     | 
| 
      
 142 
     | 
    
         
            +
                    g=g.contiguous(),
         
     | 
| 
      
 143 
     | 
    
         
            +
                    beta=beta.contiguous(),
         
     | 
| 
      
 144 
     | 
    
         
            +
                    scale=scale,
         
     | 
| 
      
 145 
     | 
    
         
            +
                    initial_state=initial_state,
         
     | 
| 
      
 146 
     | 
    
         
            +
                    inplace_final_state=inplace_final_state,
         
     | 
| 
      
 147 
     | 
    
         
            +
                    cu_seqlens=cu_seqlens,
         
     | 
| 
      
 148 
     | 
    
         
            +
                    # ssm_state_indices=ssm_state_indices,
         
     | 
| 
      
 149 
     | 
    
         
            +
                    num_accepted_tokens=None,
         
     | 
| 
      
 150 
     | 
    
         
            +
                    use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
         
     | 
| 
      
 151 
     | 
    
         
            +
                )
         
     | 
| 
      
 152 
     | 
    
         
            +
                return o, final_state
         
     | 
| 
      
 153 
     | 
    
         
            +
             
     | 
| 
      
 154 
     | 
    
         
            +
             
     | 
| 
      
 155 
     | 
    
         
            +
            @triton.heuristics(
         
     | 
| 
      
 156 
     | 
    
         
            +
                {
         
     | 
| 
      
 157 
     | 
    
         
            +
                    "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
         
     | 
| 
      
 158 
     | 
    
         
            +
                    "HAS_RESIDUAL": lambda args: args["residual"] is not None,
         
     | 
| 
      
 159 
     | 
    
         
            +
                    "HAS_WEIGHT": lambda args: args["w"] is not None,
         
     | 
| 
      
 160 
     | 
    
         
            +
                    "HAS_BIAS": lambda args: args["b"] is not None,
         
     | 
| 
      
 161 
     | 
    
         
            +
                }
         
     | 
| 
      
 162 
     | 
    
         
            +
            )
         
     | 
| 
      
 163 
     | 
    
         
            +
            @triton.jit
         
     | 
| 
      
 164 
     | 
    
         
            +
            def layer_norm_gated_fwd_kernel(
         
     | 
| 
      
 165 
     | 
    
         
            +
                x,  # pointer to the input
         
     | 
| 
      
 166 
     | 
    
         
            +
                g,  # pointer to the gate
         
     | 
| 
      
 167 
     | 
    
         
            +
                y,  # pointer to the output
         
     | 
| 
      
 168 
     | 
    
         
            +
                w,  # pointer to the weights
         
     | 
| 
      
 169 
     | 
    
         
            +
                b,  # pointer to the biases
         
     | 
| 
      
 170 
     | 
    
         
            +
                residual,  # pointer to the residual
         
     | 
| 
      
 171 
     | 
    
         
            +
                residual_out,  # pointer to the residual
         
     | 
| 
      
 172 
     | 
    
         
            +
                mean,  # pointer to the mean
         
     | 
| 
      
 173 
     | 
    
         
            +
                rstd,  # pointer to the 1/std
         
     | 
| 
      
 174 
     | 
    
         
            +
                eps,  # epsilon to avoid division by zero
         
     | 
| 
      
 175 
     | 
    
         
            +
                T,  # number of rows in x
         
     | 
| 
      
 176 
     | 
    
         
            +
                D: tl.constexpr,  # number of columns in x
         
     | 
| 
      
 177 
     | 
    
         
            +
                BT: tl.constexpr,
         
     | 
| 
      
 178 
     | 
    
         
            +
                BD: tl.constexpr,
         
     | 
| 
      
 179 
     | 
    
         
            +
                ACTIVATION: tl.constexpr,
         
     | 
| 
      
 180 
     | 
    
         
            +
                IS_RMS_NORM: tl.constexpr,
         
     | 
| 
      
 181 
     | 
    
         
            +
                STORE_RESIDUAL_OUT: tl.constexpr,
         
     | 
| 
      
 182 
     | 
    
         
            +
                HAS_RESIDUAL: tl.constexpr,
         
     | 
| 
      
 183 
     | 
    
         
            +
                HAS_WEIGHT: tl.constexpr,
         
     | 
| 
      
 184 
     | 
    
         
            +
                HAS_BIAS: tl.constexpr,
         
     | 
| 
      
 185 
     | 
    
         
            +
            ):
         
     | 
| 
      
 186 
     | 
    
         
            +
                i_t = tl.program_id(0)
         
     | 
| 
      
 187 
     | 
    
         
            +
             
     | 
| 
      
 188 
     | 
    
         
            +
                o_d = tl.arange(0, BD)
         
     | 
| 
      
 189 
     | 
    
         
            +
                m_d = o_d < D
         
     | 
| 
      
 190 
     | 
    
         
            +
             
     | 
| 
      
 191 
     | 
    
         
            +
                p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
         
     | 
| 
      
 192 
     | 
    
         
            +
                b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
         
     | 
| 
      
 193 
     | 
    
         
            +
                if HAS_RESIDUAL:
         
     | 
| 
      
 194 
     | 
    
         
            +
                    p_res = tl.make_block_ptr(
         
     | 
| 
      
 195 
     | 
    
         
            +
                        residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
         
     | 
| 
      
 196 
     | 
    
         
            +
                    )
         
     | 
| 
      
 197 
     | 
    
         
            +
                    b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
         
     | 
| 
      
 198 
     | 
    
         
            +
                if STORE_RESIDUAL_OUT:
         
     | 
| 
      
 199 
     | 
    
         
            +
                    p_res_out = tl.make_block_ptr(
         
     | 
| 
      
 200 
     | 
    
         
            +
                        residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
         
     | 
| 
      
 201 
     | 
    
         
            +
                    )
         
     | 
| 
      
 202 
     | 
    
         
            +
                    tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
         
     | 
| 
      
 203 
     | 
    
         
            +
                if not IS_RMS_NORM:
         
     | 
| 
      
 204 
     | 
    
         
            +
                    b_mean = tl.sum(b_x, axis=1) / D
         
     | 
| 
      
 205 
     | 
    
         
            +
                    p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
         
     | 
| 
      
 206 
     | 
    
         
            +
                    tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
         
     | 
| 
      
 207 
     | 
    
         
            +
                    b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
         
     | 
| 
      
 208 
     | 
    
         
            +
                    b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
         
     | 
| 
      
 209 
     | 
    
         
            +
                else:
         
     | 
| 
      
 210 
     | 
    
         
            +
                    b_xbar = tl.where(m_d[None, :], b_x, 0.0)
         
     | 
| 
      
 211 
     | 
    
         
            +
                    b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
         
     | 
| 
      
 212 
     | 
    
         
            +
                b_rstd = 1 / tl.sqrt(b_var + eps)
         
     | 
| 
      
 213 
     | 
    
         
            +
             
     | 
| 
      
 214 
     | 
    
         
            +
                p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
         
     | 
| 
      
 215 
     | 
    
         
            +
                tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))
         
     | 
| 
      
 216 
     | 
    
         
            +
             
     | 
| 
      
 217 
     | 
    
         
            +
                if HAS_WEIGHT:
         
     | 
| 
      
 218 
     | 
    
         
            +
                    b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
         
     | 
| 
      
 219 
     | 
    
         
            +
                if HAS_BIAS:
         
     | 
| 
      
 220 
     | 
    
         
            +
                    b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
         
     | 
| 
      
 221 
     | 
    
         
            +
                b_x_hat = (
         
     | 
| 
      
 222 
     | 
    
         
            +
                    (b_x - b_mean[:, None]) * b_rstd[:, None]
         
     | 
| 
      
 223 
     | 
    
         
            +
                    if not IS_RMS_NORM
         
     | 
| 
      
 224 
     | 
    
         
            +
                    else b_x * b_rstd[:, None]
         
     | 
| 
      
 225 
     | 
    
         
            +
                )
         
     | 
| 
      
 226 
     | 
    
         
            +
                b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
         
     | 
| 
      
 227 
     | 
    
         
            +
                if HAS_BIAS:
         
     | 
| 
      
 228 
     | 
    
         
            +
                    b_y = b_y + b_b[None, :]
         
     | 
| 
      
 229 
     | 
    
         
            +
             
     | 
| 
      
 230 
     | 
    
         
            +
                # swish/sigmoid output gate
         
     | 
| 
      
 231 
     | 
    
         
            +
                p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
         
     | 
| 
      
 232 
     | 
    
         
            +
                b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
         
     | 
| 
      
 233 
     | 
    
         
            +
                if ACTIVATION == "swish" or ACTIVATION == "silu":
         
     | 
| 
      
 234 
     | 
    
         
            +
                    b_y = b_y * b_g * tl.sigmoid(b_g)
         
     | 
| 
      
 235 
     | 
    
         
            +
                elif ACTIVATION == "sigmoid":
         
     | 
| 
      
 236 
     | 
    
         
            +
                    b_y = b_y * tl.sigmoid(b_g)
         
     | 
| 
      
 237 
     | 
    
         
            +
             
     | 
| 
      
 238 
     | 
    
         
            +
                # Write output
         
     | 
| 
      
 239 
     | 
    
         
            +
                p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
         
     | 
| 
      
 240 
     | 
    
         
            +
                tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
         
     | 
| 
      
 241 
     | 
    
         
            +
             
     | 
| 
      
 242 
     | 
    
         
            +
             
     | 
| 
      
 243 
     | 
    
         
            +
            @triton.heuristics(
         
     | 
| 
      
 244 
     | 
    
         
            +
                {
         
     | 
| 
      
 245 
     | 
    
         
            +
                    "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
         
     | 
| 
      
 246 
     | 
    
         
            +
                    "HAS_RESIDUAL": lambda args: args["residual"] is not None,
         
     | 
| 
      
 247 
     | 
    
         
            +
                    "HAS_WEIGHT": lambda args: args["w"] is not None,
         
     | 
| 
      
 248 
     | 
    
         
            +
                    "HAS_BIAS": lambda args: args["b"] is not None,
         
     | 
| 
      
 249 
     | 
    
         
            +
                }
         
     | 
| 
      
 250 
     | 
    
         
            +
            )
         
     | 
| 
      
 251 
     | 
    
         
            +
            @triton.jit
         
     | 
| 
      
 252 
     | 
    
         
            +
            def layer_norm_gated_fwd_kernel1(
         
     | 
| 
      
 253 
     | 
    
         
            +
                x,  # pointer to the input
         
     | 
| 
      
 254 
     | 
    
         
            +
                g,  # pointer to the gate
         
     | 
| 
      
 255 
     | 
    
         
            +
                y,  # pointer to the output
         
     | 
| 
      
 256 
     | 
    
         
            +
                w,  # pointer to the weights
         
     | 
| 
      
 257 
     | 
    
         
            +
                b,  # pointer to the biases
         
     | 
| 
      
 258 
     | 
    
         
            +
                residual,  # pointer to the residual
         
     | 
| 
      
 259 
     | 
    
         
            +
                residual_out,  # pointer to the residual
         
     | 
| 
      
 260 
     | 
    
         
            +
                mean,  # pointer to the mean
         
     | 
| 
      
 261 
     | 
    
         
            +
                rstd,  # pointer to the 1/std
         
     | 
| 
      
 262 
     | 
    
         
            +
                eps,  # epsilon to avoid division by zero
         
     | 
| 
      
 263 
     | 
    
         
            +
                D: tl.constexpr,  # number of columns in x
         
     | 
| 
      
 264 
     | 
    
         
            +
                BD: tl.constexpr,
         
     | 
| 
      
 265 
     | 
    
         
            +
                ACTIVATION: tl.constexpr,
         
     | 
| 
      
 266 
     | 
    
         
            +
                IS_RMS_NORM: tl.constexpr,
         
     | 
| 
      
 267 
     | 
    
         
            +
                STORE_RESIDUAL_OUT: tl.constexpr,
         
     | 
| 
      
 268 
     | 
    
         
            +
                HAS_RESIDUAL: tl.constexpr,
         
     | 
| 
      
 269 
     | 
    
         
            +
                HAS_WEIGHT: tl.constexpr,
         
     | 
| 
      
 270 
     | 
    
         
            +
                HAS_BIAS: tl.constexpr,
         
     | 
| 
      
 271 
     | 
    
         
            +
            ):
         
     | 
| 
      
 272 
     | 
    
         
            +
                i_t = tl.program_id(0)
         
     | 
| 
      
 273 
     | 
    
         
            +
                x += i_t * D
         
     | 
| 
      
 274 
     | 
    
         
            +
                y += i_t * D
         
     | 
| 
      
 275 
     | 
    
         
            +
                g += i_t * D
         
     | 
| 
      
 276 
     | 
    
         
            +
                if HAS_RESIDUAL:
         
     | 
| 
      
 277 
     | 
    
         
            +
                    residual += i_t * D
         
     | 
| 
      
 278 
     | 
    
         
            +
                if STORE_RESIDUAL_OUT:
         
     | 
| 
      
 279 
     | 
    
         
            +
                    residual_out += i_t * D
         
     | 
| 
      
 280 
     | 
    
         
            +
             
     | 
| 
      
 281 
     | 
    
         
            +
                o_d = tl.arange(0, BD)
         
     | 
| 
      
 282 
     | 
    
         
            +
                m_d = o_d < D
         
     | 
| 
      
 283 
     | 
    
         
            +
                b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
         
     | 
| 
      
 284 
     | 
    
         
            +
                if HAS_RESIDUAL:
         
     | 
| 
      
 285 
     | 
    
         
            +
                    b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
         
     | 
| 
      
 286 
     | 
    
         
            +
                if STORE_RESIDUAL_OUT:
         
     | 
| 
      
 287 
     | 
    
         
            +
                    tl.store(residual_out + o_d, b_x, mask=m_d)
         
     | 
| 
      
 288 
     | 
    
         
            +
                if not IS_RMS_NORM:
         
     | 
| 
      
 289 
     | 
    
         
            +
                    b_mean = tl.sum(b_x, axis=0) / D
         
     | 
| 
      
 290 
     | 
    
         
            +
                    tl.store(mean + i_t, b_mean)
         
     | 
| 
      
 291 
     | 
    
         
            +
                    b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
         
     | 
| 
      
 292 
     | 
    
         
            +
                    b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
         
     | 
| 
      
 293 
     | 
    
         
            +
                else:
         
     | 
| 
      
 294 
     | 
    
         
            +
                    b_xbar = tl.where(m_d, b_x, 0.0)
         
     | 
| 
      
 295 
     | 
    
         
            +
                    b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
         
     | 
| 
      
 296 
     | 
    
         
            +
                b_rstd = 1 / tl.sqrt(b_var + eps)
         
     | 
| 
      
 297 
     | 
    
         
            +
                tl.store(rstd + i_t, b_rstd)
         
     | 
| 
      
 298 
     | 
    
         
            +
             
     | 
| 
      
 299 
     | 
    
         
            +
                if HAS_WEIGHT:
         
     | 
| 
      
 300 
     | 
    
         
            +
                    b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
         
     | 
| 
      
 301 
     | 
    
         
            +
                if HAS_BIAS:
         
     | 
| 
      
 302 
     | 
    
         
            +
                    b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
         
     | 
| 
      
 303 
     | 
    
         
            +
                b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
         
     | 
| 
      
 304 
     | 
    
         
            +
                b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
         
     | 
| 
      
 305 
     | 
    
         
            +
                if HAS_BIAS:
         
     | 
| 
      
 306 
     | 
    
         
            +
                    b_y = b_y + b_b
         
     | 
| 
      
 307 
     | 
    
         
            +
             
     | 
| 
      
 308 
     | 
    
         
            +
                # swish/sigmoid output gate
         
     | 
| 
      
 309 
     | 
    
         
            +
                b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
         
     | 
| 
      
 310 
     | 
    
         
            +
                if ACTIVATION == "swish" or ACTIVATION == "silu":
         
     | 
| 
      
 311 
     | 
    
         
            +
                    b_y = b_y * b_g * tl.sigmoid(b_g)
         
     | 
| 
      
 312 
     | 
    
         
            +
                elif ACTIVATION == "sigmoid":
         
     | 
| 
      
 313 
     | 
    
         
            +
                    b_y = b_y * tl.sigmoid(b_g)
         
     | 
| 
      
 314 
     | 
    
         
            +
             
     | 
| 
      
 315 
     | 
    
         
            +
                # Write output
         
     | 
| 
      
 316 
     | 
    
         
            +
                tl.store(y + o_d, b_y, mask=m_d)
         
     | 
| 
      
 317 
     | 
    
         
            +
             
     | 
| 
      
 318 
     | 
    
         
            +
             
     | 
| 
      
 319 
     | 
    
         
            +
            def layer_norm_gated_fwd(
         
     | 
| 
      
 320 
     | 
    
         
            +
                x: torch.Tensor,
         
     | 
| 
      
 321 
     | 
    
         
            +
                g: torch.Tensor,
         
     | 
| 
      
 322 
     | 
    
         
            +
                weight: torch.Tensor,
         
     | 
| 
      
 323 
     | 
    
         
            +
                bias: torch.Tensor,
         
     | 
| 
      
 324 
     | 
    
         
            +
                activation: str = "swish",
         
     | 
| 
      
 325 
     | 
    
         
            +
                eps: float = 1e-5,
         
     | 
| 
      
 326 
     | 
    
         
            +
                residual: torch.Tensor = None,
         
     | 
| 
      
 327 
     | 
    
         
            +
                out_dtype: torch.dtype = None,
         
     | 
| 
      
 328 
     | 
    
         
            +
                residual_dtype: torch.dtype = None,
         
     | 
| 
      
 329 
     | 
    
         
            +
                is_rms_norm: bool = False,
         
     | 
| 
      
 330 
     | 
    
         
            +
            ):
         
     | 
| 
      
 331 
     | 
    
         
            +
                if residual is not None:
         
     | 
| 
      
 332 
     | 
    
         
            +
                    residual_dtype = residual.dtype
         
     | 
| 
      
 333 
     | 
    
         
            +
                T, D = x.shape
         
     | 
| 
      
 334 
     | 
    
         
            +
                if residual is not None:
         
     | 
| 
      
 335 
     | 
    
         
            +
                    assert residual.shape == (T, D)
         
     | 
| 
      
 336 
     | 
    
         
            +
                if weight is not None:
         
     | 
| 
      
 337 
     | 
    
         
            +
                    assert weight.shape == (D,)
         
     | 
| 
      
 338 
     | 
    
         
            +
                if bias is not None:
         
     | 
| 
      
 339 
     | 
    
         
            +
                    assert bias.shape == (D,)
         
     | 
| 
      
 340 
     | 
    
         
            +
                # allocate output
         
     | 
| 
      
 341 
     | 
    
         
            +
                y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)
         
     | 
| 
      
 342 
     | 
    
         
            +
                if residual is not None or (
         
     | 
| 
      
 343 
     | 
    
         
            +
                    residual_dtype is not None and residual_dtype != x.dtype
         
     | 
| 
      
 344 
     | 
    
         
            +
                ):
         
     | 
| 
      
 345 
     | 
    
         
            +
                    residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
         
     | 
| 
      
 346 
     | 
    
         
            +
                else:
         
     | 
| 
      
 347 
     | 
    
         
            +
                    residual_out = None
         
     | 
| 
      
 348 
     | 
    
         
            +
                mean = (
         
     | 
| 
      
 349 
     | 
    
         
            +
                    torch.empty((T,), dtype=torch.float, device=x.device)
         
     | 
| 
      
 350 
     | 
    
         
            +
                    if not is_rms_norm
         
     | 
| 
      
 351 
     | 
    
         
            +
                    else None
         
     | 
| 
      
 352 
     | 
    
         
            +
                )
         
     | 
| 
      
 353 
     | 
    
         
            +
                rstd = torch.empty((T,), dtype=torch.float, device=x.device)
         
     | 
| 
      
 354 
     | 
    
         
            +
                # Less than 64KB per feature: enqueue fused kernel
         
     | 
| 
      
 355 
     | 
    
         
            +
                MAX_FUSED_SIZE = 65536 // x.element_size()
         
     | 
| 
      
 356 
     | 
    
         
            +
                BD = min(MAX_FUSED_SIZE, next_power_of_2(D))
         
     | 
| 
      
 357 
     | 
    
         
            +
                if D > BD:
         
     | 
| 
      
 358 
     | 
    
         
            +
                    raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
         
     | 
| 
      
 359 
     | 
    
         
            +
                # heuristics for number of warps
         
     | 
| 
      
 360 
     | 
    
         
            +
             
     | 
| 
      
 361 
     | 
    
         
            +
                if D <= 512:
         
     | 
| 
      
 362 
     | 
    
         
            +
                    BT = 32
         
     | 
| 
      
 363 
     | 
    
         
            +
                    layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](
         
     | 
| 
      
 364 
     | 
    
         
            +
                        x=x,
         
     | 
| 
      
 365 
     | 
    
         
            +
                        g=g,
         
     | 
| 
      
 366 
     | 
    
         
            +
                        y=y,
         
     | 
| 
      
 367 
     | 
    
         
            +
                        w=weight,
         
     | 
| 
      
 368 
     | 
    
         
            +
                        b=bias,
         
     | 
| 
      
 369 
     | 
    
         
            +
                        residual=residual,
         
     | 
| 
      
 370 
     | 
    
         
            +
                        residual_out=residual_out,
         
     | 
| 
      
 371 
     | 
    
         
            +
                        mean=mean,
         
     | 
| 
      
 372 
     | 
    
         
            +
                        rstd=rstd,
         
     | 
| 
      
 373 
     | 
    
         
            +
                        eps=eps,
         
     | 
| 
      
 374 
     | 
    
         
            +
                        T=T,
         
     | 
| 
      
 375 
     | 
    
         
            +
                        D=D,
         
     | 
| 
      
 376 
     | 
    
         
            +
                        BD=BD,
         
     | 
| 
      
 377 
     | 
    
         
            +
                        BT=BT,
         
     | 
| 
      
 378 
     | 
    
         
            +
                        ACTIVATION=activation,
         
     | 
| 
      
 379 
     | 
    
         
            +
                        IS_RMS_NORM=is_rms_norm,
         
     | 
| 
      
 380 
     | 
    
         
            +
                        num_warps=4,
         
     | 
| 
      
 381 
     | 
    
         
            +
                    )
         
     | 
| 
      
 382 
     | 
    
         
            +
                else:
         
     | 
| 
      
 383 
     | 
    
         
            +
                    layer_norm_gated_fwd_kernel1[(T,)](
         
     | 
| 
      
 384 
     | 
    
         
            +
                        x=x,
         
     | 
| 
      
 385 
     | 
    
         
            +
                        g=g,
         
     | 
| 
      
 386 
     | 
    
         
            +
                        y=y,
         
     | 
| 
      
 387 
     | 
    
         
            +
                        w=weight,
         
     | 
| 
      
 388 
     | 
    
         
            +
                        b=bias,
         
     | 
| 
      
 389 
     | 
    
         
            +
                        residual=residual,
         
     | 
| 
      
 390 
     | 
    
         
            +
                        residual_out=residual_out,
         
     | 
| 
      
 391 
     | 
    
         
            +
                        mean=mean,
         
     | 
| 
      
 392 
     | 
    
         
            +
                        rstd=rstd,
         
     | 
| 
      
 393 
     | 
    
         
            +
                        eps=eps,
         
```python
            D=D,
            BD=BD,
            ACTIVATION=activation,
            IS_RMS_NORM=is_rms_norm,
            num_warps=4,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x


def rms_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    residual: torch.Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6,
):
    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.contiguous().reshape(-1, x.shape[-1])
    g = g.contiguous().reshape(-1, g.shape[-1])
    if residual is not None:
        assert residual.shape == x_shape_og
        residual = residual.contiguous().reshape(-1, residual.shape[-1])
    residual_dtype = (
        residual.dtype
        if residual is not None
        else (torch.float if residual_in_fp32 else None)
    )
    y, _, _, residual_out = layer_norm_gated_fwd(
        x=x,
        g=g,
        weight=weight,
        bias=bias,
        activation=activation,
        eps=eps,
        residual=residual,
        residual_dtype=residual_dtype,
        is_rms_norm=True,
    )
    y = y.reshape(x_shape_og)
    return y if not prenorm else (y, residual_out.reshape(x_shape_og))
```
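For orientation, a minimal smoke test of `rms_norm_gated` might look like the sketch below. It is not part of the diff; it assumes a CUDA device (the underlying forward is a Triton kernel), and passes `bias=None`, which matches how `FusedRMSNormGated` below invokes the function.

```python
# Illustrative smoke test for rms_norm_gated (not part of the diff).
import torch

x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)
g = torch.randn_like(x)                 # gate input, same shape as x
weight = torch.ones(64, device="cuda")  # RMSNorm scale
# bias=None mirrors FusedRMSNormGated below, which always registers bias as None
y = rms_norm_gated(x, g, weight, bias=None, activation="swish")
assert y.shape == x.shape
```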
         
```python
class FusedRMSNormGated(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        activation: str = "swish",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.activation = activation

        if self.activation not in ["swish", "silu", "sigmoid"]:
            raise ValueError(f"Unsupported activation: {self.activation}")

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
        self.register_parameter("bias", None)

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: torch.Tensor | None = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
    ) -> torch.Tensor:
        return rms_norm_gated(
            x,
            g,
            self.weight,
            self.bias,
            self.activation,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
```
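A hypothetical usage of the module wrapper follows; it is not part of the diff. Note that `weight` is allocated with `torch.empty` in `__init__` and no initializer appears in the added code shown here, so this sketch initializes it explicitly.

```python
# Illustrative usage of FusedRMSNormGated (not part of the diff).
import torch

norm = FusedRMSNormGated(
    64, eps=1e-5, activation="swish", device="cuda", dtype=torch.float16
)
torch.nn.init.ones_(norm.weight)  # weight is torch.empty-allocated above
x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)
g = torch.randn_like(x)
y = norm(x, g)  # calls rms_norm_gated(x, g, norm.weight, norm.bias, ...)
assert y.shape == x.shape
```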
         
```python
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return
    if i_i <= i_j:
        return

    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    g += (bos * H + i_h) * K
    A += (bos * H + i_h) * BT
    Aqk += (bos * H + i_h) * BT

    p_b = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
    )
    b_b = tl.load(p_b, boundary_check=(0,))

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_g = tl.make_block_ptr(
            g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        b_kt = tl.make_block_ptr(
            k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )
        p_gk = tl.make_block_ptr(
            g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )

        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K
        # [BK,]
        b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
        # [BC, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
        # [BK, BC]
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kt = tl.load(b_kt, boundary_check=(0, 1))
        # [BC, BC]
        b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
        b_A += tl.dot(b_k, b_ktg)

        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        b_Aqk += tl.dot(b_qg, b_ktg)

    b_A *= b_b[:, None]

    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    p_Aqk = tl.make_block_ptr(
        Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
```
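Reading the kernel (this gloss is mine, not text from the diff): each inter-sub-chunk tile accumulates a gated, beta-scaled Gram matrix, and factoring the exponentials through `b_gn`, the cumulative log-gate row at the start of sub-chunk `i`, keeps both factors bounded:

```latex
% rows r in sub-chunk i, columns c in sub-chunk j < i; g = cumulative log-gates
A_{rc} = \beta_r \sum_d k_{rd}\,k_{cd}\,e^{g_{rd}-g_{cd}}
       = \beta_r \sum_d \bigl(k_{rd}\,e^{g_{rd}-g_{nd}}\bigr)\bigl(k_{cd}\,e^{g_{nd}-g_{cd}}\bigr)
A^{qk}_{rc} = \mathrm{scale}\cdot\sum_d q_{rd}\,k_{cd}\,e^{g_{rd}-g_{cd}}
```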
         
```python
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + o_i) < T
    o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC

    p_q = tl.make_block_ptr(
        q + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_k = tl.make_block_ptr(
        k + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_g = tl.make_block_ptr(
        g + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))

    p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
    b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]

    p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k

    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
        b_A = tl.sum(b_k * b_ktg, 1)
        b_A = tl.where(o_i > j, b_A, 0.0)
        b_Aqk = tl.sum(b_q * b_ktg, 1)
        b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
        tl.store(A + o_A + j, b_A, mask=m_A)
        tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
        p_kt += H * K
        p_gk += H * K
```
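The diagonal sub-blocks are swept one column `j` per loop iteration. Per my reading of the masks (again, a gloss, not diff text), `A` is strictly causal while `Aqk` includes the diagonal:

```latex
% rows r and column j within the same sub-chunk
A_{rj}      = \beta_r \sum_d k_{rd}\,k_{jd}\,e^{g_{rd}-g_{jd}}\,[r > j]
A^{qk}_{rj} = \mathrm{scale}\cdot\sum_d q_{rd}\,k_{jd}\,e^{g_{rd}-g_{jd}}\,[r \ge j]
```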
         
```python
def chunk_kda_scaled_dot_kkt_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    gk: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, H, K = k.shape
    assert K <= 256
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    BC = min(16, BT)
    NC = cdiv(BT, BC)
    BK = max(next_power_of_2(K), 16)
    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    grid = (NT, NC * NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    grid = (NT, NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
    )
    return A, Aqk
```
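A hypothetical call sketch, not part of the diff; shapes follow the docstring, the log-sigmoid cumulative gates are a common construction rather than anything the diff prescribes, and `scale = K**-0.5` is a conventional choice the function leaves to the caller.

```python
# Illustrative call to chunk_kda_scaled_dot_kkt_fwd (not part of the diff).
import torch

B, T, H, K = 1, 128, 4, 64
q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
gk = torch.nn.functional.logsigmoid(torch.randn_like(q)).cumsum(1)  # cumulative log-gates
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)

A, Aqk = chunk_kda_scaled_dot_kkt_fwd(q, k, gk=gk, beta=beta, scale=K**-0.5)
assert A.shape == Aqk.shape == (B, T, H, 64)  # last dim is the chunk size BT
```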
         
```python
@triton.heuristics(
    {
        "STORE_QG": lambda args: args["qg"] is not None,
        "STORE_KG": lambda args: args["kg"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    q,
    k,
    qg,
    kg,
    v,
    beta,
    w,
    u,
    A,
    gk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    STORE_QG: tl.constexpr,
    STORE_KG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_b = tl.load(p_b, boundary_check=(0,))

    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_A = tl.load(p_A, boundary_check=(0, 1))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_k = tl.make_block_ptr(
            k + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_b[:, None]

        p_gk = tl.make_block_ptr(
            gk + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kb *= exp(b_gk)
        if STORE_QG:
            p_q = tl.make_block_ptr(
                q + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            p_qg = tl.make_block_ptr(
                qg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_qg = b_q * exp(b_gk)
            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
        if STORE_KG:
            last_idx = min(i_t * BT + BT, T) - 1

            o_k = i_k * BK + tl.arange(0, BK)
            m_k = o_k < K
            b_gn = tl.load(
                gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
            )
            b_kg = b_k * exp(b_gn - b_gk)

            p_kg = tl.make_block_ptr(
                kg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))

        b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
```
         
     | 
| 
      
 938 
     | 
    
         
            +
             
     | 
| 
      
 939 
     | 
    
         
            +
             
     | 
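In the `STORE_KG` branch above, each key is rescaled to the end of its chunk: with `g` the chunk-local cumulative log-decay, the kernel stores `kg = k * exp(g_last - g)`, so the inter-chunk state update can apply a single decay factor per chunk instead of one per timestep. A dense reference for one chunk of a single (batch, head) stream (a sketch only; `rescale_keys_reference` is a hypothetical helper, not part of this file):

```python
import torch

def rescale_keys_reference(k: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """k, g: [BT, K] for one chunk; g is the chunk-local cumulative log-decay."""
    g_last = g[-1]                      # decay accumulated up to the chunk's last step
    return k * torch.exp(g_last - g)    # matches b_kg = b_k * exp(b_gn - b_gk)
```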
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    q: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, None, torch.Tensor | None]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
    BK = 64
    BV = 64

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    w = torch.empty_like(k)
    u = torch.empty_like(v)
    kg = torch.empty_like(k) if gk is not None else None
    recompute_w_u_fwd_kernel[(NT, B * H)](
        q=q,
        k=k,
        qg=None,
        kg=kg,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        DOT_PRECISION="ieee",
    )
    return w, u, None, kg

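A hypothetical smoke test for the wrapper above; the values are arbitrary and only the shape/dtype contract is exercised (shapes follow the `[B, T, H, K]` / `[B, T, H, V]` convention used throughout this file, and a CUDA device is assumed):

```python
import torch

B, T, H, K, V = 1, 128, 2, 64, 64
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
A = torch.randn(B, T, H, 64, device="cuda", dtype=torch.bfloat16)  # [..., BT], normally from solve_tril
w, u, _, kg = recompute_w_u_fwd(k=k, v=v, beta=beta, A=A)
assert w.shape == k.shape and u.shape == v.shape and kg is None  # kg only when gk is given
```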
      
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_g = tl.make_block_ptr(
            g + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_h = tl.make_block_ptr(
            h + (i_tg * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(
        v + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_o = tl.make_block_ptr(
        o + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

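Per chunk, the kernel above computes an inter-chunk term that reads the recurrent state `h` through gate-scaled queries, plus an intra-chunk term that applies the causally masked `A` to `v`. A dense single-chunk, single-head reference (illustrative only; `chunk_o_reference` is a hypothetical name):

```python
import torch

def chunk_o_reference(q, g, h, A, v, scale):
    """q, g: [BT, K]; h: [K, V] (state at chunk start); A: [BT, BT]; v: [BT, V]."""
    BT = q.shape[0]
    m_s = torch.tril(torch.ones(BT, BT, dtype=torch.bool))  # i >= j, diagonal included
    inter = (q * scale * torch.exp(g)) @ h                  # b_qg @ b_h, summed over K blocks
    intra = A.masked_fill(~m_s, 0.0) @ v                    # tril(A) @ v
    return inter + intra
```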
      
def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    o: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    def grid(meta):
        return (cdiv(V, meta["BV"]), NT, B * H)

    chunk_gla_fwd_kernel_o[grid](
        q=q,
        v=v,
        g=g,
        h=h,
        o=o,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o

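The varlen path of this wrapper packs all sequences along `T` and drives the grid's chunk axis from `chunk_indices`, which enumerates (sequence, chunk) pairs. An illustration of the bookkeeping (the numbers are made up):

```python
import torch

cu_seqlens = torch.tensor([0, 100, 228], dtype=torch.long)  # two packed seqs: 100 and 128 tokens
# With BT = 64, prepare_chunk_indices yields one (seq, chunk) pair per chunk:
# ceil(100/64) + ceil(128/64) = 2 + 2 = 4, so NT = 4 and the grid's chunk axis has 4 programs.
```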
      
def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    chunk_size = 64
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
    # the intra-chunk Aqk is kept in fp32; this has only a marginal
    # effect on overall throughput
    A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
        q=q,
        k=k,
        gk=g,
        beta=beta,
        scale=scale,
        cu_seqlens=cu_seqlens,
        output_dtype=torch.float32,
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u, _, kg = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        gk=g,
        cu_seqlens=cu_seqlens,
    )
    del A
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=kg,
        w=w,
        u=u,
        gk=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    del w, u, kg
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v_new,
        g=g,
        A=Aqk,
        h=h,
        o=v,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    del Aqk, v_new, h
    return o, final_state

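End to end, `chunk_kda_fwd` is a chunked evaluation of a per-channel gated delta rule. A naive sequential reference for one (batch, head) stream, written under the assumptions that decay is applied before the delta update and that gates act elementwise on the key dimension (a sketch for intuition, not the fused computation; `kda_recurrent_reference` is a hypothetical name):

```python
import torch

def kda_recurrent_reference(q, k, v, g, beta, scale):
    """q, k: [T, K]; v: [T, V]; g: [T, K] log-gates (<= 0); beta: [T]."""
    T, K = k.shape
    V = v.shape[-1]
    S = torch.zeros(K, V)                            # recurrent state
    o = torch.empty(T, V)
    for t in range(T):
        S = S * torch.exp(g[t])[:, None]             # per-channel decay
        err = v[t] - (S * k[t][:, None]).sum(0)      # v_t - S^T k_t
        S = S + beta[t] * k[t][:, None] * err        # delta-rule correction
        o[t] = (S * (scale * q[t])[:, None]).sum(0)  # o_t = S^T (scale * q_t)
    return o
```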
      
def chunk_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    **kwargs,
):
    if scale is None:
        scale = k.shape[-1] ** -0.5

    if use_qk_l2norm_in_kernel:
        q = l2norm_fwd(q.contiguous())
        k = l2norm_fwd(k.contiguous())

    o, final_state = chunk_kda_fwd(
        q=q,
        k=k,
        v=v.contiguous(),
        g=g.contiguous(),
        beta=beta.contiguous(),
        scale=scale,
        initial_state=initial_state.contiguous() if initial_state is not None else None,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state

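A hypothetical end-to-end call of the public entry point above (shapes follow the file's `[B, T, H, ...]` convention; a CUDA device is assumed, and with `initial_state=None` the recurrence starts from a zero state):

```python
import torch

B, T, H, K, V = 1, 256, 4, 128, 128
q = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
g = -torch.rand(B, T, H, K, device="cuda")  # per-channel log-decay, <= 0
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
o, final_state = chunk_kda(
    q=q, k=k, v=v, g=g, beta=beta,
    output_final_state=True, use_qk_l2norm_in_kernel=True,
)
# o: [B, T, H, V]; final_state: [B, H, K, V]
```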
      
@triton.autotune(
    configs=[
        triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
        for bt in BT_LIST_AUTOTUNE
        for nw in NUM_WARPS_AUTOTUNE
        for ns in [2, 3]
    ],
    key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
    g,
    A,
    y,
    g_bias,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    T,
    H,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t, i_h = tl.program_id(0), tl.program_id(1)
    n_t = i_t * BT

    b_a = tl.load(A + i_h).to(tl.float32)
    b_a = -tl.exp(b_a)

    stride_row = H * D
    stride_col = 1

    g_ptr = tl.make_block_ptr(
        base=g + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    y_ptr = tl.make_block_ptr(
        base=y + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)

    if HAS_BIAS:
        n_d = tl.arange(0, BD)
        bias_mask = n_d < D
        b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
            tl.float32
        )
        b_g = b_g + b_bias[None, :]

    # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x)),
    # switching to the linear approximation x once beta * x > threshold
    g_scaled = b_g * beta
    use_linear = g_scaled > threshold
    sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
    b_y = b_a * sp

    tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))

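In plain PyTorch, the kernel above computes `y = -exp(A) * softplus(g + bias)`, and its threshold cut-over matches `torch.nn.functional.softplus`, which reverts to the identity once `beta * x` exceeds `threshold`. A reference (illustrative; `kda_gate_reference` is a hypothetical name):

```python
import torch
import torch.nn.functional as F

def kda_gate_reference(g, A, g_bias=None, beta=1.0, threshold=20.0):
    """g: [T, H*D] -> [T, H, D]; A: [H]; g_bias: [H*D] or None."""
    H = A.numel()
    g = g.view(g.shape[0], H, -1).float()
    if g_bias is not None:
        g = g + g_bias.view(H, -1).float()
    return -torch.exp(A.float()).view(1, H, 1) * F.softplus(g, beta=beta, threshold=threshold)
```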
      
def fused_kda_gate(
    g: torch.Tensor,
    A: torch.Tensor,
    head_k_dim: int,
    g_bias: torch.Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """
    Forward pass for the KDA gate:
      input g: [..., H*D]
      param A: [H] or [1, 1, H, 1]
      beta: softplus beta parameter
      threshold: softplus threshold parameter
      return: [..., H, D]
    """
    orig_shape = g.shape[:-1]

    g = g.view(-1, g.shape[-1])
    T = g.shape[0]
    HD = g.shape[1]
    H = A.numel()
    assert H * head_k_dim == HD

    y = torch.empty_like(g, dtype=torch.float32)

    def grid(meta):
        return (cdiv(T, meta["BT"]), H)

    kda_gate_fwd_kernel[grid](
        g,
        A,
        y,
        g_bias,
        beta,
        threshold,
        T,
        H,
        head_k_dim,
        BD=next_power_of_2(head_k_dim),
        HAS_BIAS=g_bias is not None,
    )

    y = y.view(*orig_shape, H, head_k_dim)
    return y
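A hypothetical usage of `fused_kda_gate`; the result is the per-channel log-decay `g` consumed by `chunk_kda` above (CUDA assumed):

```python
import torch

T, H, D = 512, 4, 128
g_proj = torch.randn(T, H * D, device="cuda", dtype=torch.bfloat16)  # gate projection output
A_log = torch.zeros(H, device="cuda")                                # per-head log-slope A
g = fused_kda_gate(g_proj, A_log, head_k_dim=D)
assert g.shape == (T, H, D) and g.dtype == torch.float32 and (g <= 0).all()
```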