sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,326 @@
|
|
1
|
+
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
|
2
|
+
# Copyright (c) 2024, Tri Dao.
|
3
|
+
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
4
|
+
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
5
|
+
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
6
|
+
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
7
|
+
|
8
|
+
import math
|
9
|
+
|
10
|
+
import torch
|
11
|
+
import torch.nn.functional as F
|
12
|
+
import triton
|
13
|
+
import triton.language as tl
|
14
|
+
from einops import rearrange
|
15
|
+
|
16
|
+
|
17
|
+
def rms_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    """Reference (pure PyTorch) implementation of gated RMSNorm / GroupRMSNorm.

    If ``z`` is not None, computes ``norm(x) * silu(z)`` when
    ``norm_before_gate`` is True, else ``norm(x * silu(z))``.

    Args:
        x: input tensor; normalization is over the last dimension.
        weight: scale of shape ``(x.shape[-1],)``.
        bias: optional shift of shape ``(x.shape[-1],)``, or ``None``.
        z: optional gating tensor, same shape as ``x``.
        eps: epsilon added to the mean-square before the rsqrt.
        group_size: if given, normalize each contiguous group of this many
            channels independently (GroupRMSNorm); ``None`` means one group.
        norm_before_gate: order of normalization vs. SiLU gating (see above).
        upcast: if True, do the math in fp32 and cast back to ``x.dtype``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    dtype = x.dtype
    # Parameters are always promoted to fp32 for the reduction math;
    # `upcast` only controls promotion of the activations.
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        # Split channels into groups of `group_size` and normalize per group.
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)
|
48
|
+
|
49
|
+
|
50
|
+
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X (per group)
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    """Fused single-pass LayerNorm/RMSNorm forward with optional SiLU gating.

    Each program instance handles one (row, group) pair: program_id(0) is the
    row, program_id(1) is the channel group. Mean/Rstd are written into flat
    buffers laid out as (ngroups, M), i.e. at index ``group * M + row``.
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    # Advance all pointers to this (row, group) slice of N channels.
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        # Gate first, then normalize: x *= silu(z).
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Zero out-of-range lanes so they don't pollute the variance sum.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: second moment around zero, no mean subtraction.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # `mean` is only defined on the non-RMS path; IS_RMS_NORM is a constexpr,
    # so the dead branch is compiled out.
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        # Normalize first, then gate: y *= silu(z).
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)
|
114
|
+
|
115
|
+
|
116
|
+
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Launch the fused forward kernel on a 2D input.

    Args:
        x: (M, N) tensor, last dim contiguous.
        weight: (N,) scale, contiguous.
        bias: optional (N,) shift, contiguous, or None.
        eps: epsilon for the rsqrt.
        z: optional (M, N) gating tensor, last dim contiguous.
        out: optional preallocated output of the same shape as ``x``.
        group_size: channels per normalization group; None means N (one group).
        norm_before_gate: gating order flag forwarded to the kernel.
        is_rms_norm: if True, skip mean subtraction (RMSNorm); mean is None.

    Returns:
        (out, mean, rstd) where mean/rstd are fp32 buffers of shape
        (ngroups * M,) laid out as (ngroups, M); mean is None for RMSNorm.
    """
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    # Per-(group, row) statistics; kernel indexes them as group * M + row.
    mean = (
        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    # One program per (row, group) pair.
    grid = (M, ngroups)
    # Pin the launch to x's device so Triton picks the right CUDA context.
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            M,
            group_size,
            eps,
            BLOCK_N=BLOCK_N,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=num_warps,
        )
    return out, mean, rstd
|
182
|
+
|
183
|
+
|
184
|
+
class LayerNormFn(torch.autograd.Function):
    """Autograd entry point for the fused gated LayerNorm/RMSNorm forward.

    Only ``forward`` is defined here (inference path); no backward is
    registered in this module.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        orig_shape = x.shape
        # Collapse all leading dims so the kernel sees a 2D (rows, hidden) view,
        # and make rows contiguous in the last dimension.
        x2d = x.reshape(-1, x.shape[-1])
        if x2d.stride(-1) != 1:
            x2d = x2d.contiguous()
        z2d = None
        if z is not None:
            assert z.shape == orig_shape
            z2d = z.reshape(-1, z.shape[-1])
            if z2d.stride(-1) != 1:
                z2d = z2d.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        # mean/rstd are not needed without a backward pass.
        y, _, _ = _layer_norm_fwd(
            x2d,
            weight,
            bias,
            eps,
            z=z2d,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return y.reshape(orig_shape)
|
224
|
+
|
225
|
+
|
226
|
+
def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Functional gated LayerNorm/RMSNorm.

    If z is not None, computes norm(x) * silu(z) when norm_before_gate,
    else norm(x * silu(z)).
    """
    # autograd.Function.apply takes positional arguments only.
    args = (x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
    return LayerNormFn.apply(*args)
|
239
|
+
|
240
|
+
|
241
|
+
def rmsnorm_fn(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    """Functional gated RMSNorm: layernorm_fn with is_rms_norm fixed to True."""
    # autograd.Function.apply takes positional arguments only; the trailing
    # True is is_rms_norm.
    args = (x, weight, bias, z, eps, group_size, norm_before_gate, True)
    return LayerNormFn.apply(*args)
|
247
|
+
|
248
|
+
|
249
|
+
class LayerNorm(torch.nn.Module):
    """LayerNorm module with optional channel grouping and SiLU gating."""

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.eps = eps
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones and bias to zeros."""
        with torch.no_grad():
            self.weight.fill_(1.0)
            self.bias.zero_()

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
|
288
|
+
|
289
|
+
|
290
|
+
class RMSNorm(torch.nn.Module):
    """RMSNorm module with optional channel grouping and SiLU gating (no bias)."""

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.eps = eps
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        # RMSNorm carries no shift; register a None bias so state_dict and
        # attribute access stay consistent with LayerNorm.
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones."""
        with torch.no_grad():
            self.weight.fill_(1.0)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
|
@@ -0,0 +1,66 @@
|
|
1
|
+
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py
|
2
|
+
# -*- coding: utf-8 -*-
|
3
|
+
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
4
|
+
|
5
|
+
import os
|
6
|
+
|
7
|
+
import triton
|
8
|
+
import triton.language as tl
|
9
|
+
import triton.language.extra.libdevice as tldevice
|
10
|
+
|
11
|
+
from sglang.srt.layers.attention.fla.utils import is_gather_supported
|
12
|
+
|
13
|
+
# Bind the transcendental primitives used by the FLA kernels. Setting
# FLA_USE_FAST_OPS=1 swaps in the fast libdevice variants; the default is
# the standard Triton ops.
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp, exp2, log, log2 = (
        tldevice.fast_expf,
        tldevice.exp2,
        tldevice.fast_logf,
        tldevice.fast_log2f,
    )
else:
    exp, exp2, log, log2 = tl.exp, tl.math.exp2, tl.log, tl.log2
|
23
|
+
|
24
|
+
|
25
|
+
@triton.jit
def safe_exp(x):
    # Overflow-safe exponential: for x <= 0 return exp(x); positive inputs
    # are replaced with -inf so the result is 0 instead of overflowing.
    # Uses the module-level `exp` binding (fast or standard, see above).
    return exp(tl.where(x <= 0, x, float("-inf")))
|
28
|
+
|
29
|
+
|
30
|
+
# Provide a module-level `gather` symbol regardless of Triton version, so
# kernels can reference it unconditionally.
if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """
        Gather operation that works when tl.gather is not supported.
        This is a fallback implementation that returns None.
        Just to make triton compiler happy.
        """
        # NOTE: callers must guard on is_gather_supported before relying on
        # the result; this stub only keeps the compiler satisfied.
        return None

else:
    gather = tl.gather
|
43
|
+
|
44
|
+
|
45
|
+
# Resolve the TMA tensor-descriptor constructor across Triton versions; the
# API was experimental in 3.3.x and stabilized in 3.4.x.
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    """
    Fallback implementation when TMA is not supported.
    Returns None to indicate TMA descriptors are unavailable.
    Just make triton compiler happy.
    """

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        # Stub: callers must detect TMA support separately before using
        # the returned descriptor.
        return None
|