sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,640 @@
|
|
1
|
+
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
4
|
+
|
5
|
+
from typing import Optional, Tuple
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import triton
|
9
|
+
import triton.language as tl
|
10
|
+
|
11
|
+
from sglang.srt.layers.attention.fla.op import exp
|
12
|
+
from sglang.srt.layers.attention.fla.utils import input_guard
|
13
|
+
|
14
|
+
|
15
|
+
# Fused recurrent (token-by-token) forward of the gated delta rule.
#
# Program grid: axis 0 -> K block (i_k), axis 1 -> V block (i_v),
# axis 2 -> flattened (sequence, value head) index i_nh = i_n * HV + i_hv.
# Each program keeps a [BK, BV] fp32 state tile b_h and, for every token t:
#   b_h *= exp(g_t)
#   v'   = beta_t * (v_t - sum_over_K(b_h * k_t))
#   b_h += outer(k_t, v')
#   o_t  = sum_over_K(b_h * (scale * q_t))
# One query/key head i_h is shared by HV // H consecutive value heads (GVA).
# Element offsets are computed in int64: the program ids are int32 and
# (sequence, head) * K * V or token * HV * V can exceed 2**31 for large inputs.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    STORE_FINAL_STATE: tl.constexpr,  # whether to store final state
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Widen before any offset arithmetic to avoid int32 overflow.
    i_nh = i_nh.to(tl.int64)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos = tl.load(cu_seqlens + i_n).to(tl.int64)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        # With cu_seqlens the batch is flattened, so T is the total token count.
        T_total = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        T_total = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        # beta carries one value per V channel: [.., HV, V]
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        # beta carries one scalar per value head: [.., HV]
        p_beta = beta + bos * HV + i_hv
    p_g = g + bos * HV + i_hv
    # o has a leading NK dimension; i_k selects the K-block slice.
    p_o = o + ((i_k * T_total + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        p_h0 = h0 + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
        b_q = b_q * scale
        # [BK, BV] decay the state by exp(g)
        b_h *= exp(b_g)
        # [BV] delta-rule residual: v - h^T k
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV] rank-1 state update
        b_h += b_k[:, None] * b_v[None, :]
        # [BV] output readout
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # advance one token
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

    if STORE_FINAL_STATE:
        p_ht = ht + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
117
|
+
|
118
|
+
|
119
|
+
def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Launch the fused recurrent gated delta-rule forward kernel.

    Args:
        q, k: ``[B, T, H, K]`` queries / keys.
        v: ``[B, T, HV, V]`` values; ``HV`` must be a multiple of ``H`` (GVA).
        g: per-token, per-value-head log decay, indexed as ``[B, T, HV]``.
        beta: ``[B, T, HV]`` (one scalar per head) or ``[B, T, HV, V]``
            (headwise, detected via ``beta.ndim == v.ndim``).
        scale: multiplier applied to ``q`` inside the kernel.
        initial_state: optional ``[N, HV, K, V]`` starting state.
        output_final_state: whether to return the final ``[N, HV, K, V]``
            state (allocated as float32).
        use_qk_l2norm_in_kernel: L2-normalize ``q`` and ``k`` per token.
        cu_seqlens: optional ``[N + 1]`` cumulative sequence lengths for
            flattened variable-length input (``B == 1``).

    Returns:
        ``(o, final_state)``: ``o`` has the shape of ``v`` and the dtype of
        ``q``; ``final_state`` is ``None`` unless ``output_final_state``.

    Raises:
        ValueError: if ``HV`` is not a multiple of ``H``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    # The kernel maps value head i_hv to q/k head i_hv // (HV // H); any other
    # ratio divides by zero or reads q/k out of bounds.
    if HV % H != 0:
        raise ValueError(
            f"The number of value heads ({HV}) must be a multiple of the number "
            f"of query/key heads ({H})."
        )
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    o = q.new_empty(NK, *v.shape)
    if output_final_state:
        final_state = q.new_empty(N, HV, K, V, dtype=torch.float32)
    else:
        final_state = None

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # Drop the leading NK (== 1) dimension.
    o = o.squeeze(0)
    return o, final_state
|
173
|
+
|
174
|
+
|
175
|
+
class FusedRecurrentFunction(torch.autograd.Function):
    """Forward-only autograd wrapper around ``fused_recurrent_gated_delta_rule_fwd``.

    ``forward`` returns ``(o, final_state)`` from the fused kernel launcher;
    ``backward`` always raises ``NotImplementedError``.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        # Nothing is saved on ctx because gradients are not supported.
        output, last_state = fused_recurrent_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            cu_seqlens=cu_seqlens,
        )
        return output, last_state

    @staticmethod
    @input_guard
    def backward(ctx, do, dht):
        reason = (
            "Backward pass is not implemented yet and we do not have plans to implement it "
            "because we haven't figured out how to compute dg without materializing the full "
            "hidden states for all time steps."
        )
        raise NotImplementedError(reason)
|
215
|
+
|
216
|
+
|
217
|
+
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`; `HV` must be a multiple of `H`.
        g (torch.Tensor):
            g (log decays) of shape `[B, T, HV]`.
        beta (Optional[torch.Tensor]):
            betas of shape `[B, T, HV]`, or headwise betas of shape `[B, T, HV, V]`.
            If not provided, defaults to all ones of shape `[B, T, HV]`.
        scale (Optional[float]):
            Scale factor applied to the queries.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        use_qk_l2norm_in_kernel (bool):
            Whether to L2-normalize `q` and `k` inside the kernel. Default: `False`.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`.
    Raises:
        ValueError: if `cu_seqlens` is given with a batch size other than 1, or
            with a number of initial states different from the number of sequences.
    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from sglang.srt.layers.attention.fla.fused_recurrent import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # The kernel reads one beta per *value* head ([B, T, HV]). Sizing it from
        # q ([B, T, H]) would be too small under GVA (HV > H).
        beta = q.new_ones(v.shape[:-1])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
|
318
|
+
|
319
|
+
|
320
|
+
# In-place / pooled-state variant of the fused recurrent gated delta rule.
#
# The initial state of sequence i_n is read from slot h0_indices[i_n] of
# h0_source (indexed as [slot, HV, K, V]); negative slots are skipped. Unless
# DISABLE_STATE_UPDATE is set, the final state is written back to the same slot.
# When intermediate_states_buffer is provided, the state after every token is
# stored at [slot, step, HV, K, V] with cache_steps steps reserved per slot.
# Slot indices and program ids are widened to int64 before forming offsets:
# slot * cache_steps * HV * K * V easily exceeds 2**31 for realistic pool sizes.
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"]
        is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def fused_recurrent_gated_delta_rule_update_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    scale,
    intermediate_states_buffer,
    cache_steps,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DISABLE_STATE_UPDATE: tl.constexpr,  # whether to disable final state update
    DISABLE_OUTPUT_CALCULATION: tl.constexpr,  # whether to disable output calculation
    CACHE_INTERMEDIATE_STATES: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Widen before any offset arithmetic to avoid int32 overflow.
    i_nh = i_nh.to(tl.int64)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos = tl.load(cu_seqlens + i_n).to(tl.int64)
        eos = tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        # With cu_seqlens the batch is flattened, so T is the total token count.
        T_total = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        T_total = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv
    p_g = g + bos * HV + i_hv
    p_o = o + ((i_k * T_total + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n).to(tl.int64)
        # Negative slot indices are treated as "no state" and skipped.
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Prepare intermediate state cache variables if enabled
    cache_idx = -1
    if CACHE_INTERMEDIATE_STATES:
        cache_idx = tl.load(h0_indices + i_n).to(tl.int64)

    step_idx = 0
    for _ in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
        b_q = b_q * scale
        # [BK, BV] decay the state by exp(g)
        b_h *= exp(b_g)
        # [BV] delta-rule residual: v - h^T k
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV] rank-1 state update
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        if not DISABLE_OUTPUT_CALCULATION:
            b_o = tl.sum(b_h * b_q[:, None], 0)
            # core attn output
            tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # store intermediate states if enabled
        if CACHE_INTERMEDIATE_STATES:
            if cache_idx >= 0:
                # Compute cache pointer for this step
                step_offset = step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_k[:, None] * V
                    + o_v[None, :]
                )
                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)

        step_idx += 1

        # advance one token
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_g += HV
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)

    # Store the final (SSM) state back to its h0_source slot, skipping negative slots.
    if not DISABLE_STATE_UPDATE:
        idx = tl.load(h0_indices + i_n).to(tl.int64)
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
|
471
|
+
|
472
|
+
|
473
|
+
def fused_recurrent_gated_delta_rule_update_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    disable_state_update: bool = False,
    disable_output_calculation: bool = False,
    intermediate_states_buffer: Optional[torch.Tensor] = None,
    cache_steps: Optional[int] = None,
) -> torch.Tensor:
    """Launch ``fused_recurrent_gated_delta_rule_update_fwd_kernel``.

    ``k`` is unpacked as ``(B, T, H, K)`` and ``v`` as ``(B, T, HV, V)``.
    One program runs per (K block, V block, sequence * value head). V is
    tiled in blocks of at most 8, and K must fit in a single block.

    Args:
        q, k, v, g, beta: Per-token inputs forwarded unchanged to the kernel.
        scale: Multiplier applied to ``q`` inside the kernel.
        initial_state_source: State pool. Unless ``disable_state_update`` is
            set, the final state is written back into it at the slot loaded
            from ``initial_state_indices``; negative slots are not written.
        initial_state_indices: Per-sequence slot index into
            ``initial_state_source``.
        use_qk_l2norm_in_kernel: L2-normalize ``q`` and ``k`` inside the kernel.
        cu_seqlens: Cumulative sequence lengths for packed variable-length
            input; the number of sequences is ``len(cu_seqlens) - 1``.
        disable_state_update: Skip the final-state write-back.
        disable_output_calculation: Skip computing/storing the output; a tiny
            placeholder tensor is returned instead.
        intermediate_states_buffer: Optional buffer that receives the state
            after each step. The kernel addresses it contiguously as
            ``[slot, step, HV, K, V]`` with ``cache_steps`` steps per slot.
        cache_steps: Steps reserved per slot in ``intermediate_states_buffer``.
            Required whenever the buffer is given.

    Returns:
        Output tensor shaped like ``v``, or a ``(1, 1, 1, 1)`` placeholder when
        ``disable_output_calculation`` is set.

    Raises:
        ValueError: If ``intermediate_states_buffer`` is given without
            ``cache_steps``.
    """
    if intermediate_states_buffer is not None and cache_steps is None:
        # The kernel offsets each slot by cache_idx * cache_steps * HV*K*V.
        # Falling back to 0 would make every slot alias slot 0, so sequences
        # would silently overwrite each other's cached states.
        raise ValueError(
            "`cache_steps` must be provided when `intermediate_states_buffer` is given."
        )

    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    if disable_output_calculation:
        # The kernel never stores to `o` in this mode; allocate a placeholder.
        o = q.new_empty(NK, 1, 1, 1, 1)
    else:
        o = q.new_empty(NK, *v.shape)

    grid = (NK, NV, N * HV)

    # CACHE_INTERMEDIATE_STATES (and any USE_INITIAL_STATE / IS_VARLEN flags)
    # are not passed here. Presumably they are derived by a kernel heuristic
    # from the arguments -- confirm at the kernel definition.
    fused_recurrent_gated_delta_rule_update_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        scale=scale,
        intermediate_states_buffer=intermediate_states_buffer,
        # The kernel needs an int; the value is irrelevant without a buffer.
        cache_steps=0 if cache_steps is None else cache_steps,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        DISABLE_STATE_UPDATE=disable_state_update,
        DISABLE_OUTPUT_CALCULATION=disable_output_calculation,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # Drop the leading NK (== 1) dimension.
    o = o.squeeze(0)
    return o
|
536
|
+
|
537
|
+
|
538
|
+
class FusedRecurrentUpdateFunction(torch.autograd.Function):
    """Autograd wrapper around ``fused_recurrent_gated_delta_rule_update_fwd``.

    Only the forward pass is supported; ``backward`` always raises
    ``NotImplementedError``.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state_source: torch.Tensor,
        initial_state_indices: torch.Tensor,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
        disable_state_update: bool = False,
        disable_output_calculation: bool = False,
        intermediate_states_buffer: Optional[torch.Tensor] = None,
        cache_steps: Optional[int] = None,
    ):
        """Forward all arguments to the launcher and return its single output.

        ``ctx`` is unused: nothing is saved because backward is unsupported.
        """
        o = fused_recurrent_gated_delta_rule_update_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state_source=initial_state_source,
            initial_state_indices=initial_state_indices,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            cu_seqlens=cu_seqlens,
            disable_state_update=disable_state_update,
            disable_output_calculation=disable_output_calculation,
            intermediate_states_buffer=intermediate_states_buffer,
            cache_steps=cache_steps,
        )

        return o

    @staticmethod
    @input_guard
    def backward(ctx, do, dht=None):
        """Always raise: the backward pass is not implemented.

        ``forward`` returns a single tensor, so autograd supplies exactly one
        gradient (``do``). ``dht`` therefore defaults to ``None``; without that
        default the call fails with a ``TypeError`` before the intended error
        below is reached.
        """
        raise NotImplementedError(
            "Backward pass is not implemented yet and we do not have plans to implement it "
            "because we haven't figured out how to compute dg without materializing the full "
            "hidden states for all time steps."
        )
|
586
|
+
|
587
|
+
|
588
|
+
def fused_recurrent_gated_delta_rule_update(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state_source: Optional[torch.Tensor] = None,
    initial_state_indices: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
    disable_state_update: bool = False,
    disable_output_calculation: bool = False,
    intermediate_states_buffer: Optional[torch.Tensor] = None,
    cache_steps: Optional[int] = None,
) -> torch.Tensor:
    """Validate arguments, fill defaults and run the fused recurrent update.

    ``k`` is unpacked by the launcher as ``(B, T, H, K)`` and ``v`` as
    ``(B, T, HV, V)``.

    Args:
        q, k, v, g: Per-token inputs forwarded to the kernel.
        beta: Per-token (or per-token, per-channel when ``beta.ndim == v.ndim``)
            factor. Defaults to ones with one entry per value head, i.e.
            ``v.shape[:-1]``.
        scale: Query multiplier; defaults to ``k.shape[-1] ** -0.5`` and must
            be positive when given.
        initial_state_source: State pool used for initial/final states.
        initial_state_indices: Per-sequence slot into
            ``initial_state_source``. With ``cu_seqlens`` it must have one entry
            per sequence.
        cu_seqlens: Cumulative sequence lengths for packed variable-length
            input. Requires a batch dimension of 1.
        use_qk_l2norm_in_kernel, disable_state_update,
        disable_output_calculation, intermediate_states_buffer, cache_steps:
            Forwarded unchanged to the launcher.

    Returns:
        Output shaped like ``v``, or a small placeholder when
        ``disable_output_calculation`` is set.

    Raises:
        ValueError: If ``cu_seqlens`` is used with a batch size other than 1,
            or if the number of initial-state indices does not match the
            number of sequences.
    """
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if (
            initial_state_source is not None
            and initial_state_indices.shape[0] != len(cu_seqlens) - 1
        ):
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # The kernel strides beta by HV (value heads) per token, so the default
        # must follow v's head count. q has only H heads, which differs from HV
        # whenever there are more value heads than query/key heads.
        beta = torch.ones_like(v[..., 0])
    o = FusedRecurrentUpdateFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state_source,
        initial_state_indices,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
        disable_state_update,
        disable_output_calculation,
        intermediate_states_buffer,
        cache_steps,
    )
    return o
|