sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba.py
ADDED
@@ -0,0 +1,64 @@
+from typing import Callable, List, Tuple
+
+import torch
+
+LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
+
+
+def mamba_v2_sharded_weight_loader(
+    shard_spec: List[Tuple[int, int, float]],
+    tp_size: int,
+    tp_rank: int,
+) -> LoaderFunction:
+    """Create a weight loader for mamba v2. This ensures that the projections
+    are correctly sharded so that they can be split into x, B, C. It also
+    ensures the the all the groups corresponding to a head shard is placed
+    together with it.
+    """
+
+    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+
+        # - track boundary of (sharded) param, and loaded_weight, respectively
+        boundary, loaded_boundary = 0, 0
+
+        # - iterate over the shard specs
+        for full_dim, extra, duplicate_groups in shard_spec:
+            # - full dim is the model dim (before TP).
+            # - extra > 0, means there is expected overall increase
+            #   of dimensions. This is so because of replication.
+            # - ratio is used map the tp_rank to the actual shard
+            #   rank. This is useful when there is replication of
+            #   groups to accompany head shards.
+
+            # - size of the loaded shard
+            shard_size = full_dim // tp_size
+
+            # - compute the rank into the loaded shard.
+            # - if there is replication, different TP shards will
+            #   take from the same rank.
+            # NOTE: currently we only support duplication
+            # in the case where num_groups == 1
+            rank = 0 if duplicate_groups else tp_rank
+
+            # - leftmost boundary index into loaded weight.
+            loaded_skip = rank * shard_size
+            loaded_start_idx = loaded_boundary + loaded_skip
+
+            # - take these many dims from the loaded weight.
+            take = min(shard_size, full_dim - extra - loaded_skip)
+
+            # - always shard on dim 0
+            # - the ignore is for a mundane mypy error as it does not
+            #   seem to handle slices well.
+            # https://github.com/python/mypy/issues/2410
+            param.data[
+                boundary : (boundary + take), ...  # type: ignore[misc]
+            ] = loaded_weight[
+                loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
+            ]  # type: ignore[misc]
+
+            # move indexing boundaries
+            boundary += shard_size
+            loaded_boundary += full_dim - extra
+
+    return loader
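For orientation, here is a small usage sketch of the new loader. The shard_spec values, tensor shapes, and direct call below are illustrative assumptions, not part of the diff; only the loader itself comes from the new module.

# Illustrative only: shard_spec and shapes are made up for demonstration.
import torch

from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader

tp_size, tp_rank = 2, 1
# Two concatenated projection blocks (widths 8 and 4), no group replication.
shard_spec = [(8, 0, False), (4, 0, False)]  # (full_dim, extra, duplicate_groups)

loader = mamba_v2_sharded_weight_loader(shard_spec, tp_size, tp_rank)

full_weight = torch.arange(12 * 3, dtype=torch.float32).reshape(12, 3)
param = torch.empty(8 // tp_size + 4 // tp_size, 3)  # this rank's shard (6 rows)
loader(param, full_weight)

# Rank 1 receives rows 4:8 of the first block and rows 10:12 of the second,
# so the per-rank shard still splits cleanly into its segments.
assert torch.equal(param[:4], full_weight[4:8])
assert torch.equal(param[4:], full_weight[10:12])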
sglang/srt/layers/attention/torch_native_backend.py
CHANGED
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -45,12 +45,21 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None
 
 
+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
 
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
+    max_seq_len: Optional[int] = None
 
 
 class TRTLLMMLABackend(FlashInferMLAAttnBackend):
@@ -100,7 +109,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -176,9 +186,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
 
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
 
@@ -207,8 +214,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         )
 
         # Custom fast-path for decode/idle.
-
-
+        # Capture with full width so future longer sequences are safe during replay
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_blocks_per_seq]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -217,16 +225,22 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             None,
             block_kv_indices,
             self.req_to_token.stride(0),
-
+            max_blocks_per_seq,
             NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
             PAGED_SIZE=self.page_size,
         )
 
+        # Record the true maximum sequence length for this capture batch so that
+        # the kernel launch path (which requires an int not a tensor) can reuse
+        # it safely during both capture and replay.
+        max_seq_len_val = int(seq_lens.max().item())
+
         metadata = TRTLLMMLADecodeMetadata(
-
+            block_kv_indices,
+            max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.
+        self.forward_decode_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -268,6 +282,13 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
+        # Update stored max_seq_len so subsequent kernel calls use the correct value
+        # Prefer CPU tensor to avoid GPU synchronization when available.
+        if seq_lens_cpu is not None:
+            metadata.max_seq_len = int(seq_lens_cpu.max().item())
+        else:
+            metadata.max_seq_len = int(seq_lens.max().item())
+
     def get_cuda_graph_seq_len_fill_value(self) -> int:
         """Get the fill value for sequence lengths in CUDA graph."""
         return 1
@@ -275,30 +296,52 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
         # Delegate to parent for non-decode modes.
-        if
-
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
 
-
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
 
-
-
-
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
+
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
-
-
-            max_seqlen_pad = self._calc_padded_blocks(max_seq)
-            block_kv_indices = self._create_block_kv_indices(
-                bs,
-                max_seqlen_pad,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
-            )
+            return super().init_forward_metadata(forward_batch)
 
-
-
-            )
-            forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)
 
     def quantize_and_rope_for_fp8(
         self,
@@ -442,7 +485,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.
+            or self.forward_decode_metadata
         )
 
         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -465,20 +508,67 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
             block_tables=metadata.block_kv_indices,
             seq_lens=forward_batch.seq_lens.to(torch.int32),
-            max_seq_len=
+            max_seq_len=metadata.max_seq_len,
             bmm1_scale=bmm1_scale,
         )
 
-        #
-
-        output
+        # Reshape output directly without slicing
+        output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        return output
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
 
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
         return output
 
 
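To illustrate the TRTLLMMLAPrefillMetadata built in init_forward_metadata above: the cumulative query lengths are just a zero-prefixed cumsum of the per-request extend lengths, and the maximum is handed to the kernel as a plain Python int. A minimal sketch with made-up lengths (not taken from the diff):

# Minimal sketch; the per-request lengths are illustrative.
import torch

seq_lens = torch.tensor([3, 5, 2])  # new tokens per request in this extend batch
cum_seq_lens_q = torch.cat(
    (torch.tensor([0]), torch.cumsum(seq_lens, dim=0))
).int()

print(cum_seq_lens_q.tolist())          # [0, 3, 8, 10] -> ragged row offsets
max_seq_len = int(seq_lens.max().item())  # 5, passed to the kernel as an int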
sglang/srt/layers/attention/wave_ops/decode_attention.py
CHANGED
@@ -64,8 +64,7 @@ def get_wave_kernel(
         subs=hyperparams_0,
         canonicalize=True,
         run_bench=False,
-
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         dynamic_symbols=dynamic_symbols_0,
         wave_runtime=True,
@@ -77,8 +76,7 @@ def get_wave_kernel(
         subs=hyperparams_1,
         canonicalize=True,
         run_bench=False,
-
-        use_buffer_store_ops=False,
+        use_buffer_ops=False,
         waves_per_eu=4,
         dynamic_symbols=dynamic_symbols_1,
         wave_runtime=True,
sglang/srt/layers/attention/wave_ops/extend_attention.py
CHANGED
@@ -67,11 +67,9 @@ def get_wave_kernel(
         schedule=SchedulingType.NONE,
         use_scheduling_barriers=False,
         dynamic_symbols=dynamic_symbols,
-
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
-        gpu_native_math_precision=True,
         wave_runtime=True,
     )
     options = set_default_run_config(options)
sglang/srt/layers/communicator.py
CHANGED
@@ -42,10 +42,24 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cuda,
+    is_flashinfer_available,
+    is_gfx95_supported,
+    is_hip,
+    is_sm90_supported,
+    is_sm100_supported,
+)
 
 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+_is_gfx95_supported = is_gfx95_supported()
+
+if _use_aiter and _is_gfx95_supported:
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
 
@@ -201,6 +215,7 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        qaunt_format: str = "",
     ):
         if hidden_states.shape[0] == 0:
             residual = hidden_states
@@ -218,11 +233,34 @@ class LayerCommunicator:
         else:
             if residual is None:
                 residual = hidden_states
-
+
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                    hidden_states = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        None,
+                    )
+                else:
+                    hidden_states = self.input_layernorm(hidden_states)
             else:
-
-                hidden_states, residual
-
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                    hidden_states, residual = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        residual,
+                    )
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )
 
         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
@@ -484,11 +522,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
            and _is_flashinfer_available
            and hasattr(layernorm, "forward_with_allreduce_fusion")
            and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <=
+            and hidden_states.shape[0] <= 4096
        ):
            hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                hidden_states, residual
sglang/srt/layers/layernorm.py
CHANGED
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from packaging.version import Version
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
@@ -25,35 +26,41 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_npu,
+    is_xpu,
     supports_custom_op,
 )
 
 _is_cuda = is_cuda()
+_is_flashinfer_available = is_flashinfer_available()
 _is_hip = is_hip()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
-
-    fused_add_rmsnorm
-
-
-
-    )
+    if _is_flashinfer_available:
+        from flashinfer.norm import fused_add_rmsnorm
+    else:
+        from sgl_kernel import fused_add_rmsnorm
+    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
 elif _is_hip:
+    import vllm
     from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
+    _vllm_version = Version(vllm.__version__)
+
 logger = logging.getLogger(__name__)
 
-if
+if _is_npu:
     import torch_npu
 
 
@@ -127,8 +134,21 @@ class RMSNorm(CustomOp):
         # NOTE: Remove this if aiter kernel supports discontinuous input
         x = x.contiguous()
         if residual is not None:
-
-
+            if _vllm_version < Version("0.9"):
+                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+                return x, residual
+            else:
+                residual_out = torch.empty_like(x)
+                output = torch.empty_like(x)
+                fused_add_rms_norm(
+                    output,
+                    x,
+                    residual_out,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
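Both branches above compute the same fused residual-add + RMSNorm; only the calling convention differs (in-place before vllm 0.9, separate output tensors afterwards). A plain-PyTorch reference of the assumed semantics, useful for sanity-checking either path (exact dtype handling in the real kernels may differ):

# Assumed semantics of fused_add_rms_norm, spelled out in plain PyTorch.
import torch

def add_rms_norm_reference(x, residual, weight, eps):
    hidden = x.float() + residual.float()            # fused residual add
    variance = hidden.pow(2).mean(-1, keepdim=True)  # RMS statistics
    normed = hidden * torch.rsqrt(variance + eps)
    # returns (normalized output, post-add activations as the new residual)
    return (normed * weight.float()).type_as(x), hidden.type_as(x)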
@@ -266,28 +286,50 @@ class GemmaRMSNorm(CustomOp):
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
 
+    def forward_npu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
+        return x if residual is None else (x, residual)
 
-
+
+class Gemma3RMSNorm(CustomOp):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.zeros(dim))
+        # Re-dispatch
 
     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def
+    def forward_native(self, x):
         output = self._norm(x.float())
         # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
         # See https://github.com/huggingface/transformers/pull/29402
         output = output * (1.0 + self.weight.float())
         return output.type_as(x)
 
+    def forward_cuda(self, x):
+        return self.forward_native(x)
+
+    def forward_npu(self, x):
+        output, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
+        return output
+
     def extra_repr(self):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (
+if not (
+    _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
sglang/srt/layers/logits_processor.py
CHANGED
@@ -46,10 +46,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file, use_intel_amx_backend
+from sglang.srt.utils import dump_to_file, is_npu, use_intel_amx_backend
 
 logger = logging.getLogger(__name__)
 
+_is_npu = is_npu()
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -61,7 +63,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    #
+    # The log probs of output tokens, if RETURN_ORIGINAL_LOGPROB = True, will get the log probs before applying temperature. If False, will get the log probs before applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
@@ -517,7 +519,12 @@ class LogitsProcessor(nn.Module):
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
-
+            if not _is_npu:
+                fused_softcap(logits, self.final_logit_softcapping)
+            else:
+                logits = self.final_logit_softcapping * torch.tanh(
+                    logits / self.final_logit_softcapping
+                )
 
         return logits
 
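The NPU fallback above is the standard tanh soft-cap written out explicitly; the fused CUDA kernel computes the same transform. For reference:

# Reference formula for logit soft-capping (same math as the NPU branch above).
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    return cap * torch.tanh(logits / cap)

# e.g. softcap(torch.tensor([100.0]), 30.0) is about 29.9,
# keeping all values inside (-cap, cap).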
sglang/srt/layers/moe/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner import MoeRunner, MoeRunnerConfig
 from sglang.srt.layers.moe.utils import (
     DeepEPMode,
     MoeA2ABackend,
@@ -17,6 +17,7 @@ from sglang.srt.layers.moe.utils import (
 __all__ = [
     "DeepEPMode",
     "MoeA2ABackend",
+    "MoeRunner",
     "MoeRunnerConfig",
     "MoeRunnerBackend",
     "initialize_moe_config",