sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -12,6 +12,7 @@ from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
+    get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
@@ -26,13 +27,19 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
-    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+    from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+else:
+    FusedSetKVBufferArg = None
+
 if _use_aiter:
     from aiter.rotary_embedding import get_rope as aiter_get_rope
 
 if is_npu():
     import torch_npu
 
+NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
+NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
+
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., : x.shape[-1] // 2]
@@ -142,8 +149,13 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-native implementation of forward()."""
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for native implementation"
+
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -172,12 +184,17 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """A PyTorch-npu implementation of forward()."""
-
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for npu implementation"
 
         if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
         else:
             rotary_mode = "half"
             if self.is_neox_style:
@@ -202,7 +219,12 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for cpu implementation"
+
         positions = torch.add(positions, offsets) if offsets is not None else positions
         if _is_cpu_amx_available:
             return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -214,7 +236,9 @@ class RotaryEmbedding(CustomOp):
                 self.is_neox_style,
             )
         else:
-            return self.forward_native(positions, query, key, offsets)
+            return self.forward_native(
+                positions, query, key, offsets, fused_set_kv_buffer_arg
+            )
 
     def forward_cuda(
         self,
@@ -222,7 +246,7 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
-        fused_set_kv_buffer_arg
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
@@ -782,27 +806,33 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-
-
-        return self.forward_native(positions, query, key, offsets)
-        num_tokens = query.shape[0]
-        rotary_mode = "half" if self.is_neox_style else "interleave"
+        num_tokens, num_q_heads, _ = query.shape
+        num_k_heads = key.shape[1]
+
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device)
+        cos_sin = self.cos_sin_cache[
+            torch.add(positions, offsets) if offsets is not None else positions
+        ]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        # Reshape to [batchsize, head_dim, seq, rotary_dim]
+        cos = cos.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2).unsqueeze(-2)
+
         query_rot = query[..., : self.rotary_dim]
         key_rot = key[..., : self.rotary_dim]
         if self.rotary_dim < self.head_size:
             query_pass = query[..., self.rotary_dim :]
             key_pass = key[..., self.rotary_dim :]
 
-        query_rot
-
-
-
-
-
-
-
+        query_rot = torch_npu.npu_interleave_rope(
+            query_rot.reshape(num_tokens, num_q_heads, 1, self.rotary_dim),
+            cos,
+            sin,
+        )
+        key_rot = torch_npu.npu_interleave_rope(
+            key_rot.reshape(num_tokens, num_k_heads, 1, self.rotary_dim),
+            cos,
+            sin,
         )
         query_rot = query_rot.reshape(num_tokens, -1, self.rotary_dim)
         key_rot = key_rot.reshape(num_tokens, -1, self.rotary_dim)
@@ -1029,12 +1059,13 @@ class MRotaryEmbedding(RotaryEmbedding):
             f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
         )
 
-    @torch.compile(dynamic=True)
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """PyTorch-native implementation equivalent to forward().
 
@@ -1045,6 +1076,9 @@ class MRotaryEmbedding(RotaryEmbedding):
             query: [num_tokens, num_heads * head_size]
             key: [num_tokens, num_kv_heads * head_size]
         """
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "save kv cache is not supported for MRotaryEmbedding."
         assert positions.ndim == 1 or positions.ndim == 2
 
         num_tokens = positions.shape[-1]
@@ -1177,7 +1211,7 @@ class MRotaryEmbedding(RotaryEmbedding):
 
             time_tensor_long = time_tensor.long()
             t_index = time_tensor_long.flatten()
-        elif model_type == "qwen2_vl":
+        elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
             t_index = (
                 torch.arange(llm_grid_t)
                 .view(-1, 1)
@@ -1888,17 +1922,30 @@ def apply_rotary_pos_emb_npu(
     sin: torch.Tensor,
     unsqueeze_dim=1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
+    """Ascend implementation equivalent to apply_rotary_pos_emb_native.
+
+    Args:
+        q: [num_tokens, num_heads, head_size]
+        k: [num_tokens, num_kv_heads, head_size]
+        cos: [num_tokens, head_size]
+        sin: [num_tokens, head_size]
+    """
+    if (
+        cos.dim() != 2
+        or q.dim() != 3
+        or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
+        or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
+    ):
+        # Note: num_heads and head_size of q must be less than 1000 and 896, respectively
         return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
-    cos = cos.unsqueeze(unsqueeze_dim)
-
-
-
-
-
-    q_embed
-
-    k_embed = torch.transpose(k_embed, 1, 2)
+    cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
+    q = q.unsqueeze(0)
+    k = k.unsqueeze(0)
+    q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
+    k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
+    q_embed = q_embed.squeeze(0)
+    k_embed = k_embed.squeeze(0)
     return q_embed, k_embed
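For orientation, the new Ascend path dispatches to the fused torch_npu.npu_rotary_mul kernel only when tensor shapes fit the kernel's limits, and otherwise falls back to apply_rotary_pos_emb_native. A minimal sketch of that shape guard using the constants above; the helper name can_use_fused_npu_rope is illustrative, not part of the package:

    import torch

    NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
    NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896

    def can_use_fused_npu_rope(q: torch.Tensor, cos: torch.Tensor) -> bool:
        # Mirrors the guard in apply_rotary_pos_emb_npu: take the fused path only
        # for 2-D cos and 3-D q whose head dims stay under the kernel limits.
        return (
            cos.dim() == 2
            and q.dim() == 3
            and q.shape[1] < NPU_ROTARY_MUL_MAX_NUM_HEADS
            and q.shape[2] < NPU_ROTARY_MUL_MAX_HEAD_SIZE
        )

    q = torch.randn(4, 32, 128)  # [num_tokens, num_heads, head_size]
    cos = torch.randn(4, 128)    # [num_tokens, head_size]
    assert can_use_fused_npu_rope(q, cos)  # shapes qualify for the fused kernel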
sglang/srt/layers/sampler.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -65,6 +65,7 @@ class Sampler(nn.Module):
         return_logprob: bool,
         top_logprobs_nums: List[int],
         token_ids_logprobs: List[List[int]],
+        positions: torch.Tensor,
     ):
         """Run a sampler & compute logprobs and update logits_output accordingly.
 
@@ -77,6 +78,8 @@
         batch_next_token_ids: next token IDs. If set, skip sampling and only
             compute output logprobs It is used for speculative decoding which
             performs sampling in draft workers.
+        positions: The positions of the tokens in the sequence. Used for deterministic sampling
+            to get the unique seed for each position.
         """
         logits = logits_output.next_token_logits
 
@@ -124,6 +127,8 @@
                 sampling_info.top_ps,
                 sampling_info.min_ps,
                 sampling_info.need_min_p_sampling,
+                sampling_info.sampling_seed,
+                positions,
             )
         else:
             raise ValueError(
@@ -189,6 +194,7 @@
         Optimized for prefill-only scoring requests that need token probabilities
         but don't require next token generation.
         """
+
         if logits_output.next_token_logits is None:
             logger.warning("No logits available for logprob computation")
             return
@@ -230,8 +236,14 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
     need_min_p_sampling: bool,
+    sampling_seed: Optional[torch.Tensor],
+    positions: torch.Tensor,
 ):
-    """
+    """
+    A top-k, top-p and min-p sampling implementation with native pytorch operations.
+    When sampling_seed is not None, deterministic inference will be enabled, it will sample
+    with the sampling_seed of each request.
+    """
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
     probs_sort[
@@ -243,14 +255,50 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     if need_min_p_sampling:
         min_p_thresholds = probs_sort[:, 0] * min_ps
         probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-
-    sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    if sampling_seed is not None:
+        sampled_index = multinomial_with_seed(probs_sort, sampling_seed, positions)
+    else:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
 
 
+def multinomial_with_seed(
+    inputs: torch.Tensor, seed: torch.Tensor, positions: torch.Tensor
+) -> torch.Tensor:
+    """
+    Samples n elements from an input tensor `inputs` of shape (n, m) using
+    a unique random seed for each row. This is a deterministic batched alternative to
+    `torch.multinomial`.
+
+    Args:
+        inputs: A float tensor of shape (n, m) representing n categorical
+                distributions with m categories each. The values are treated
+                as weights and do not need to sum to 1.
+        seed: An integer tensor of shape (n,) containing the random seed
+              for each corresponding row in `inputs`.
+        positions: The positions of the tokens in the sequence. Used for deterministic sampling
+                   to get the unique seed for each position.
+
+    Returns:
+        A tensor of shape (n,) where the i-th element is an index sampled
+        from the distribution in `inputs[i]` using `seed[i]`.
+    """
+    n, m = inputs.shape
+    col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
+    step_seed = seed * 19349663 ^ positions * 73856093
+    seed_expanded = step_seed.unsqueeze(-1)
+    hashed = seed_expanded * 8589934591 ^ col_indices * 479001599
+    uniform_samples = (hashed % (2**24)).float() / (2**24)
+    epsilon = 1e-9
+    gumbel_noise = -torch.log(-torch.log(uniform_samples + epsilon) + epsilon)
+    log_probs = torch.log(inputs + epsilon)
+    perturbed_log_probs = log_probs + gumbel_noise
+    return torch.argmax(perturbed_log_probs, dim=1, keepdim=True)
+
+
 def sampling_from_probs_torch(probs: torch.Tensor):
     """A sampling implementation with native pytorch operations, without
     top-k, top-p, or min-p filtering."""
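The multinomial_with_seed helper added above replaces torch.multinomial with a hash-based Gumbel-max draw, so a request's samples are reproducible per (seed, position) pair. A self-contained repro of the same construction; the toy tensors are illustrative:

    import torch

    def multinomial_with_seed(inputs, seed, positions):
        # Hash (seed, position, column) into pseudo-uniforms, then apply the
        # Gumbel-max trick: argmax of log-weights plus Gumbel noise is a
        # categorical sample, and here a deterministic one.
        n, m = inputs.shape
        col_indices = torch.arange(m, device=inputs.device).unsqueeze(0)
        step_seed = seed * 19349663 ^ positions * 73856093
        hashed = step_seed.unsqueeze(-1) * 8589934591 ^ col_indices * 479001599
        uniform_samples = (hashed % (2**24)).float() / (2**24)
        eps = 1e-9
        gumbel_noise = -torch.log(-torch.log(uniform_samples + eps) + eps)
        return torch.argmax(torch.log(inputs + eps) + gumbel_noise, dim=1, keepdim=True)

    probs = torch.rand(2, 8)
    seed = torch.tensor([42, 42])
    pos = torch.tensor([7, 7])
    assert torch.equal(
        multinomial_with_seed(probs, seed, pos),
        multinomial_with_seed(probs, seed, pos),
    )  # same (seed, position) -> same sampled token on every call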
sglang/srt/layers/utils.py
CHANGED
@@ -15,6 +15,29 @@ def get_layer_id(weight_name):
     return None
 
 
+def pad_or_narrow_weight(
+    loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
+) -> torch.Tensor:
+    # Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
+    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
+
+    if valid_size > 0:
+        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
+        pad_shape = list(loaded_weight.shape)
+        pad_shape[input_dim] = shard_size - valid_size
+        pad = torch.zeros(
+            pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+        )
+        return torch.cat([loaded_slice, pad], dim=input_dim)
+
+    # All padding
+    pad_shape = list(loaded_weight.shape)
+    pad_shape[input_dim] = shard_size
+    return torch.zeros(
+        pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
+    )
+
+
 class PPMissingLayer(torch.nn.Identity):
     # Adapted from
     # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
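pad_or_narrow_weight covers the last shard of a weight whose sharded dimension is not evenly divisible: it slices whatever rows exist past start_idx and zero-pads up to shard_size. A quick exercise of the function as defined in the hunk above; the 10-row weight and shard layout are made up for illustration:

    import torch

    def pad_or_narrow_weight(loaded_weight, input_dim, start_idx, shard_size):
        # Same logic as the hunk above: keep the valid slice, zero-pad the rest.
        valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
        if valid_size > 0:
            loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
            pad_shape = list(loaded_weight.shape)
            pad_shape[input_dim] = shard_size - valid_size
            pad = torch.zeros(pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device)
            return torch.cat([loaded_slice, pad], dim=input_dim)
        pad_shape = list(loaded_weight.shape)
        pad_shape[input_dim] = shard_size
        return torch.zeros(pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device)

    w = torch.randn(10, 4)  # 10 rows split into shards of 3 -> last shard is short
    last_shard = pad_or_narrow_weight(w, input_dim=0, start_idx=9, shard_size=3)
    assert last_shard.shape == (3, 4)  # 1 real row plus 2 zero rows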
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -143,10 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 
         return TritonLoRABackend
-
-
+    elif name == "csgmv":
+        from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
 
-
+        return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
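With the csgmv branch in place, the chunked SGMV LoRA backend resolves by name like the existing triton backend. A hypothetical lookup, assuming get_backend_from_name lives in sglang.srt.lora.backend.base_backend (as the hunk context suggests) and that the installed package includes this change:

    # Both the import path and the printed class name are inferred from the
    # hunk above, not verified against the installed package.
    from sglang.srt.lora.backend.base_backend import get_backend_from_name

    backend_cls = get_backend_from_name("csgmv")
    print(backend_cls.__name__)  # ChunkedSgmvLoRABackend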