sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
|
@@ -85,7 +85,7 @@ def execute_sbo(
|
|
|
85
85
|
_compute_overlap_args(dispatch_output, alt_stream, disable_sbo=disable_sbo)
|
|
86
86
|
)
|
|
87
87
|
|
|
88
|
-
|
|
88
|
+
combine_input = experts.run_moe_core(
|
|
89
89
|
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
|
|
90
90
|
)
|
|
91
91
|
if (e := meta_overlap_args.get("record_event_after_down")) is not None:
|
|
@@ -99,9 +99,7 @@ def execute_sbo(
|
|
|
99
99
|
forward_shared_experts()
|
|
100
100
|
|
|
101
101
|
hidden_states = experts.dispatcher.combine(
|
|
102
|
-
|
|
103
|
-
topk_ids=dispatch_output.topk_ids,
|
|
104
|
-
topk_weights=dispatch_output.topk_weights,
|
|
102
|
+
combine_input=combine_input,
|
|
105
103
|
overlap_args=combine_overlap_args,
|
|
106
104
|
)
|
|
107
105
|
|
|
@@ -49,6 +49,7 @@ class DraftBackendFactory:
|
|
|
49
49
|
"trtllm_mha": self._create_trtllm_mha_decode_backend,
|
|
50
50
|
"trtllm_mla": self._create_trtllm_mla_decode_backend,
|
|
51
51
|
"nsa": self._create_nsa_decode_backend,
|
|
52
|
+
"ascend": self._create_ascend_decode_backend,
|
|
52
53
|
}
|
|
53
54
|
|
|
54
55
|
return self._create_backend(
|
|
@@ -72,6 +73,7 @@ class DraftBackendFactory:
|
|
|
72
73
|
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
|
|
73
74
|
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
|
|
74
75
|
"nsa": self._create_nsa_prefill_backend,
|
|
76
|
+
"ascend": self._create_ascend_prefill_backend,
|
|
75
77
|
}
|
|
76
78
|
backend_name = (
|
|
77
79
|
"decode_attention_backend"
|
|
@@ -173,6 +175,15 @@ class DraftBackendFactory:
|
|
|
173
175
|
self.draft_model_runner, self.topk, self.speculative_num_steps
|
|
174
176
|
)
|
|
175
177
|
|
|
178
|
+
def _create_ascend_decode_backend(self):
|
|
179
|
+
from sglang.srt.layers.attention.ascend_backend import (
|
|
180
|
+
AscendAttnMultiStepDraftBackend,
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
return AscendAttnMultiStepDraftBackend(
|
|
184
|
+
self.draft_model_runner, self.topk, self.speculative_num_steps
|
|
185
|
+
)
|
|
186
|
+
|
|
176
187
|
def _create_flashinfer_prefill_backend(self):
|
|
177
188
|
if not get_global_server_args().use_mla_backend:
|
|
178
189
|
from sglang.srt.layers.attention.flashinfer_backend import (
|
|
@@ -219,6 +230,11 @@ class DraftBackendFactory:
|
|
|
219
230
|
|
|
220
231
|
return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
|
|
221
232
|
|
|
233
|
+
def _create_ascend_prefill_backend(self):
|
|
234
|
+
from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
|
|
235
|
+
|
|
236
|
+
return AscendAttnBackend(self.draft_model_runner)
|
|
237
|
+
|
|
222
238
|
def _create_flashmla_prefill_backend(self):
|
|
223
239
|
logger.warning(
|
|
224
240
|
"flashmla prefill backend is not yet supported for draft extend."
|
|
@@ -24,12 +24,13 @@ from sglang.srt.speculative.eagle_info_v2 import (
|
|
|
24
24
|
EagleDraftInputV2Mixin,
|
|
25
25
|
EagleVerifyInputV2Mixin,
|
|
26
26
|
)
|
|
27
|
+
from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
|
|
27
28
|
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
|
28
29
|
from sglang.srt.speculative.spec_utils import (
|
|
29
30
|
SIMULATE_ACC_LEN,
|
|
30
31
|
TREE_SPEC_KERNEL_AVAILABLE,
|
|
31
32
|
align_evict_mask_to_page_size,
|
|
32
|
-
|
|
33
|
+
assign_req_to_token_pool_func,
|
|
33
34
|
create_accept_length_filter,
|
|
34
35
|
create_extend_after_decode_spec_info,
|
|
35
36
|
filter_finished_cache_loc_kernel,
|
|
@@ -37,17 +38,16 @@ from sglang.srt.speculative.spec_utils import (
|
|
|
37
38
|
get_src_tgt_cache_loc,
|
|
38
39
|
get_target_cache_loc,
|
|
39
40
|
)
|
|
40
|
-
from sglang.srt.utils import is_cuda,
|
|
41
|
+
from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
|
|
42
|
+
|
|
43
|
+
_is_npu = is_npu()
|
|
41
44
|
|
|
42
45
|
if is_cuda():
|
|
43
46
|
from sgl_kernel import (
|
|
44
47
|
top_k_renorm_prob,
|
|
45
48
|
top_p_renorm_prob,
|
|
46
49
|
tree_speculative_sampling_target_only,
|
|
47
|
-
verify_tree_greedy,
|
|
48
50
|
)
|
|
49
|
-
elif is_hip():
|
|
50
|
-
from sgl_kernel import verify_tree_greedy
|
|
51
51
|
|
|
52
52
|
logger = logging.getLogger(__name__)
|
|
53
53
|
|
|
@@ -77,18 +77,22 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
77
77
|
|
|
78
78
|
@classmethod
|
|
79
79
|
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
|
|
80
|
+
if not _is_npu:
|
|
81
|
+
device = "cuda"
|
|
82
|
+
else:
|
|
83
|
+
device = "npu"
|
|
80
84
|
return cls(
|
|
81
|
-
draft_token=torch.empty((0,), dtype=torch.long, device=
|
|
82
|
-
custom_mask=torch.full((0,), True, dtype=torch.bool, device=
|
|
83
|
-
positions=torch.empty((0,), dtype=torch.int64, device=
|
|
85
|
+
draft_token=torch.empty((0,), dtype=torch.long, device=device),
|
|
86
|
+
custom_mask=torch.full((0,), True, dtype=torch.bool, device=device),
|
|
87
|
+
positions=torch.empty((0,), dtype=torch.int64, device=device),
|
|
84
88
|
retrive_index=torch.full(
|
|
85
|
-
(0, num_verify_tokens), -1, dtype=torch.long, device=
|
|
89
|
+
(0, num_verify_tokens), -1, dtype=torch.long, device=device
|
|
86
90
|
),
|
|
87
91
|
retrive_next_token=torch.full(
|
|
88
|
-
(0, num_verify_tokens), -1, dtype=torch.long, device=
|
|
92
|
+
(0, num_verify_tokens), -1, dtype=torch.long, device=device
|
|
89
93
|
),
|
|
90
94
|
retrive_next_sibling=torch.full(
|
|
91
|
-
(0, num_verify_tokens), -1, dtype=torch.long, device=
|
|
95
|
+
(0, num_verify_tokens), -1, dtype=torch.long, device=device
|
|
92
96
|
),
|
|
93
97
|
retrive_cum_len=None,
|
|
94
98
|
topk=topk,
|
|
@@ -134,14 +138,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
134
138
|
self.last_loc = last_loc
|
|
135
139
|
|
|
136
140
|
bs = batch.batch_size()
|
|
137
|
-
|
|
141
|
+
assign_req_to_token_pool_func(
|
|
138
142
|
batch.req_pool_indices,
|
|
139
143
|
batch.req_to_token_pool.req_to_token,
|
|
140
144
|
batch.seq_lens,
|
|
141
145
|
end_offset,
|
|
142
146
|
batch.out_cache_loc,
|
|
143
|
-
|
|
144
|
-
next_power_of_2(bs),
|
|
147
|
+
bs,
|
|
145
148
|
)
|
|
146
149
|
|
|
147
150
|
def generate_attn_arg_prefill(
|
|
@@ -151,16 +154,17 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
151
154
|
paged_kernel_lens_sum: int,
|
|
152
155
|
req_to_token: torch.Tensor,
|
|
153
156
|
):
|
|
157
|
+
device = req_pool_indices.device
|
|
154
158
|
batch_size = len(req_pool_indices)
|
|
155
159
|
qo_indptr = torch.arange(
|
|
156
160
|
0,
|
|
157
161
|
(1 + batch_size) * self.draft_token_num,
|
|
158
162
|
step=self.draft_token_num,
|
|
159
163
|
dtype=torch.int32,
|
|
160
|
-
device=
|
|
164
|
+
device=device,
|
|
161
165
|
)
|
|
162
166
|
cum_kv_seq_len = torch.zeros(
|
|
163
|
-
(batch_size + 1,), dtype=torch.int32, device=
|
|
167
|
+
(batch_size + 1,), dtype=torch.int32, device=device
|
|
164
168
|
)
|
|
165
169
|
|
|
166
170
|
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
|
@@ -169,7 +173,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
169
173
|
kv_indices = torch.empty(
|
|
170
174
|
paged_kernel_lens_sum + self.draft_token_num * batch_size,
|
|
171
175
|
dtype=torch.int32,
|
|
172
|
-
device=
|
|
176
|
+
device=device,
|
|
173
177
|
)
|
|
174
178
|
create_flashinfer_kv_indices_triton[(batch_size,)](
|
|
175
179
|
req_to_token,
|
|
@@ -226,11 +230,11 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
226
230
|
|
|
227
231
|
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
|
|
228
232
|
predict_shape[-1] += 1
|
|
229
|
-
predict = torch.empty(predict_shape, dtype=torch.int32, device=
|
|
233
|
+
predict = torch.empty(predict_shape, dtype=torch.int32, device=batch.device)
|
|
230
234
|
accept_index = torch.full(
|
|
231
|
-
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device=
|
|
235
|
+
(bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device
|
|
232
236
|
)
|
|
233
|
-
accept_length = torch.empty((bs,), dtype=torch.int32, device=
|
|
237
|
+
accept_length = torch.empty((bs,), dtype=torch.int32, device=batch.device)
|
|
234
238
|
|
|
235
239
|
if bs != len(sampling_info):
|
|
236
240
|
sampling_info = copy.deepcopy(sampling_info)
|
|
@@ -254,7 +258,7 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
254
258
|
linear_penalty = torch.zeros(
|
|
255
259
|
(bs, logits_output.next_token_logits.shape[1]),
|
|
256
260
|
dtype=torch.float32,
|
|
257
|
-
device=
|
|
261
|
+
device=batch.device,
|
|
258
262
|
)
|
|
259
263
|
sampling_info.apply_logits_bias(linear_penalty)
|
|
260
264
|
logits_output.next_token_logits.add_(
|
|
@@ -276,11 +280,10 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
276
280
|
"Falling back to greedy verification."
|
|
277
281
|
)
|
|
278
282
|
|
|
279
|
-
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
|
|
283
|
+
if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE or _is_npu:
|
|
280
284
|
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
|
|
281
285
|
target_predict = target_predict.reshape(bs, self.draft_token_num)
|
|
282
|
-
|
|
283
|
-
verify_tree_greedy(
|
|
286
|
+
predict, accept_index, accept_length = verify_tree_greedy_func(
|
|
284
287
|
predicts=predict, # mutable
|
|
285
288
|
accept_index=accept_index, # mutable
|
|
286
289
|
accept_token_num=accept_length, # mutable
|
|
@@ -289,7 +292,9 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
289
292
|
retrive_next_token=self.retrive_next_token,
|
|
290
293
|
retrive_next_sibling=self.retrive_next_sibling,
|
|
291
294
|
target_predict=target_predict,
|
|
295
|
+
topk=self.topk,
|
|
292
296
|
)
|
|
297
|
+
|
|
293
298
|
else:
|
|
294
299
|
# apply temperature and get target probs
|
|
295
300
|
expanded_temperature = torch.repeat_interleave(
|
|
@@ -315,14 +320,16 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
315
320
|
target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
|
|
316
321
|
|
|
317
322
|
draft_probs = torch.zeros(
|
|
318
|
-
target_probs.shape, dtype=torch.float32, device=
|
|
323
|
+
target_probs.shape, dtype=torch.float32, device=batch.device
|
|
319
324
|
)
|
|
320
325
|
|
|
321
326
|
# coins for rejection sampling
|
|
322
|
-
coins = torch.rand_like(
|
|
327
|
+
coins = torch.rand_like(
|
|
328
|
+
candidates, dtype=torch.float32, device=batch.device
|
|
329
|
+
)
|
|
323
330
|
# coins for final sampling
|
|
324
331
|
coins_for_final_sampling = torch.rand(
|
|
325
|
-
(bs,), dtype=torch.float32, device=
|
|
332
|
+
(bs,), dtype=torch.float32, device=batch.device
|
|
326
333
|
)
|
|
327
334
|
tree_speculative_sampling_target_only(
|
|
328
335
|
predicts=predict, # mutable
|
|
@@ -468,14 +475,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
468
475
|
if not has_finished:
|
|
469
476
|
if page_size == 1 or self.topk == 1:
|
|
470
477
|
batch.out_cache_loc = batch.out_cache_loc[accept_index]
|
|
471
|
-
|
|
478
|
+
assign_req_to_token_pool_func(
|
|
472
479
|
batch.req_pool_indices,
|
|
473
480
|
batch.req_to_token_pool.req_to_token,
|
|
474
481
|
batch.seq_lens,
|
|
475
482
|
batch.seq_lens + accept_length + 1,
|
|
476
483
|
batch.out_cache_loc,
|
|
477
|
-
|
|
478
|
-
next_power_of_2(bs),
|
|
484
|
+
bs,
|
|
479
485
|
)
|
|
480
486
|
else:
|
|
481
487
|
batch.out_cache_loc = tgt_cache_loc
|
|
@@ -501,14 +507,13 @@ class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
|
|
|
501
507
|
)
|
|
502
508
|
else:
|
|
503
509
|
if page_size == 1 or self.topk == 1:
|
|
504
|
-
|
|
510
|
+
assign_req_to_token_pool_func(
|
|
505
511
|
batch.req_pool_indices,
|
|
506
512
|
batch.req_to_token_pool.req_to_token,
|
|
507
513
|
batch.seq_lens,
|
|
508
514
|
batch.seq_lens + accept_length + 1,
|
|
509
515
|
batch.out_cache_loc[accept_index],
|
|
510
|
-
|
|
511
|
-
next_power_of_2(bs),
|
|
516
|
+
bs,
|
|
512
517
|
)
|
|
513
518
|
batch.seq_lens.add_(accept_length + 1)
|
|
514
519
|
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
|
|
@@ -695,17 +700,18 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
|
|
|
695
700
|
paged_kernel_lens_sum: int,
|
|
696
701
|
req_to_token: torch.Tensor,
|
|
697
702
|
):
|
|
703
|
+
device = req_pool_indices.device
|
|
698
704
|
bs = self.accept_length.numel()
|
|
699
|
-
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=
|
|
705
|
+
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
|
|
700
706
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
|
701
|
-
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=
|
|
707
|
+
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device)
|
|
702
708
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
|
703
709
|
|
|
704
710
|
if paged_kernel_lens_sum is None:
|
|
705
711
|
paged_kernel_lens_sum = cum_kv_seq_len[-1]
|
|
706
712
|
|
|
707
713
|
kv_indices = torch.empty(
|
|
708
|
-
paged_kernel_lens_sum, dtype=torch.int32, device=
|
|
714
|
+
paged_kernel_lens_sum, dtype=torch.int32, device=device
|
|
709
715
|
)
|
|
710
716
|
|
|
711
717
|
create_flashinfer_kv_indices_triton[(bs,)](
|
|
@@ -23,11 +23,16 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|
|
23
23
|
)
|
|
24
24
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
25
25
|
from sglang.srt.server_args import get_global_server_args
|
|
26
|
+
from sglang.srt.speculative.eagle_utils import verify_tree_greedy_func
|
|
26
27
|
from sglang.srt.speculative.spec_utils import (
|
|
27
28
|
SIMULATE_ACC_LEN,
|
|
28
29
|
generate_simulated_accept_index,
|
|
29
30
|
)
|
|
30
|
-
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2
|
|
31
|
+
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, is_npu, next_power_of_2
|
|
32
|
+
|
|
33
|
+
_is_cuda = is_cuda()
|
|
34
|
+
_is_hip = is_hip()
|
|
35
|
+
_is_npu = is_npu()
|
|
31
36
|
|
|
32
37
|
if TYPE_CHECKING:
|
|
33
38
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
@@ -41,11 +46,8 @@ if is_cuda():
|
|
|
41
46
|
top_k_renorm_prob,
|
|
42
47
|
top_p_renorm_prob,
|
|
43
48
|
tree_speculative_sampling_target_only,
|
|
44
|
-
verify_tree_greedy,
|
|
45
49
|
)
|
|
46
50
|
from sgl_kernel.top_k import fast_topk
|
|
47
|
-
elif is_hip():
|
|
48
|
-
from sgl_kernel import verify_tree_greedy
|
|
49
51
|
|
|
50
52
|
|
|
51
53
|
@triton.jit
|
|
@@ -78,7 +80,7 @@ def assign_draft_cache_locs_page_size_1(
|
|
|
78
80
|
@dataclass
|
|
79
81
|
class EagleDraftInputV2Mixin:
|
|
80
82
|
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
|
|
81
|
-
from sglang.srt.speculative.spec_utils import
|
|
83
|
+
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func
|
|
82
84
|
|
|
83
85
|
bs = batch.batch_size()
|
|
84
86
|
|
|
@@ -112,15 +114,15 @@ class EagleDraftInputV2Mixin:
|
|
|
112
114
|
extend_num_tokens,
|
|
113
115
|
)
|
|
114
116
|
|
|
115
|
-
|
|
117
|
+
assign_req_to_token_pool_func(
|
|
116
118
|
batch.req_pool_indices,
|
|
117
119
|
batch.req_to_token_pool.req_to_token,
|
|
118
120
|
self.allocate_lens,
|
|
119
121
|
new_allocate_lens,
|
|
120
122
|
out_cache_loc,
|
|
121
|
-
|
|
122
|
-
next_power_of_2(bs),
|
|
123
|
+
bs,
|
|
123
124
|
)
|
|
125
|
+
|
|
124
126
|
self.allocate_lens = new_allocate_lens
|
|
125
127
|
|
|
126
128
|
# FIXME(lsyin): make this sync optional
|
|
@@ -199,22 +201,16 @@ class EagleVerifyInputV2Mixin:
|
|
|
199
201
|
bs = len(batch.req_pool_indices)
|
|
200
202
|
batch.input_ids = self.draft_token
|
|
201
203
|
device = batch.input_ids.device
|
|
202
|
-
batch.out_cache_loc =
|
|
203
|
-
|
|
204
|
-
|
|
204
|
+
batch.out_cache_loc = assign_extend_cache_locs_func(
|
|
205
|
+
req_pool_indices=batch.req_pool_indices,
|
|
206
|
+
req_to_token=req_to_token_pool.req_to_token,
|
|
207
|
+
start_offset=batch.seq_lens,
|
|
208
|
+
end_offset=batch.seq_lens + self.draft_token_num,
|
|
209
|
+
batch_size=bs,
|
|
210
|
+
draft_token_num=self.draft_token_num,
|
|
205
211
|
device=device,
|
|
206
212
|
)
|
|
207
213
|
|
|
208
|
-
assign_extend_cache_locs[(bs,)](
|
|
209
|
-
batch.req_pool_indices,
|
|
210
|
-
req_to_token_pool.req_to_token,
|
|
211
|
-
batch.seq_lens,
|
|
212
|
-
batch.seq_lens + self.draft_token_num,
|
|
213
|
-
batch.out_cache_loc,
|
|
214
|
-
req_to_token_pool.req_to_token.shape[1],
|
|
215
|
-
next_power_of_2(bs),
|
|
216
|
-
)
|
|
217
|
-
|
|
218
214
|
# Get a forward batch
|
|
219
215
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
|
220
216
|
batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
@@ -258,11 +254,10 @@ class EagleVerifyInputV2Mixin:
|
|
|
258
254
|
accept_length = torch.empty((bs,), dtype=torch.int32, device=device)
|
|
259
255
|
|
|
260
256
|
# Sample tokens
|
|
261
|
-
if sampling_info.is_all_greedy:
|
|
257
|
+
if sampling_info.is_all_greedy or _is_npu:
|
|
262
258
|
target_predict = torch.argmax(next_token_logits, dim=-1)
|
|
263
259
|
target_predict = target_predict.reshape(bs, self.draft_token_num)
|
|
264
|
-
|
|
265
|
-
verify_tree_greedy(
|
|
260
|
+
predict, accept_index, accept_length = verify_tree_greedy_func(
|
|
266
261
|
predicts=predict, # mutable
|
|
267
262
|
accept_index=accept_index, # mutable
|
|
268
263
|
accept_token_num=accept_length, # mutable
|
|
@@ -271,6 +266,7 @@ class EagleVerifyInputV2Mixin:
|
|
|
271
266
|
retrive_next_token=self.retrive_next_token,
|
|
272
267
|
retrive_next_sibling=self.retrive_next_sibling,
|
|
273
268
|
target_predict=target_predict,
|
|
269
|
+
topk=self.topk,
|
|
274
270
|
)
|
|
275
271
|
else:
|
|
276
272
|
# Apply temperature and get target probs
|
|
@@ -338,7 +334,7 @@ class EagleVerifyInputV2Mixin:
|
|
|
338
334
|
return predict, accept_length, accept_index
|
|
339
335
|
|
|
340
336
|
|
|
341
|
-
@torch.compile(dynamic=True)
|
|
337
|
+
@torch.compile(dynamic=True, disable=_is_npu)
|
|
342
338
|
def select_top_k_tokens_tmp(
|
|
343
339
|
i: int,
|
|
344
340
|
topk_p: torch.Tensor,
|
|
@@ -456,3 +452,50 @@ def assign_extend_cache_locs(
|
|
|
456
452
|
tl.store(out_cache_ptr + save_offset, data, mask=mask)
|
|
457
453
|
load_offset += BLOCK_SIZE
|
|
458
454
|
save_offset += BLOCK_SIZE
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def assign_extend_cache_locs_func(
|
|
458
|
+
req_pool_indices: torch.Tensor,
|
|
459
|
+
req_to_token: torch.Tensor,
|
|
460
|
+
start_offset: torch.Tensor,
|
|
461
|
+
end_offset: torch.Tensor,
|
|
462
|
+
batch_size: int,
|
|
463
|
+
draft_token_num: int,
|
|
464
|
+
device,
|
|
465
|
+
) -> torch.Tensor:
|
|
466
|
+
if _is_cuda or _is_hip:
|
|
467
|
+
out_cache_loc = torch.empty(
|
|
468
|
+
(batch_size * draft_token_num,),
|
|
469
|
+
dtype=torch.int64,
|
|
470
|
+
device=device,
|
|
471
|
+
)
|
|
472
|
+
assign_extend_cache_locs[(batch_size,)](
|
|
473
|
+
req_pool_indices,
|
|
474
|
+
req_to_token,
|
|
475
|
+
start_offset,
|
|
476
|
+
end_offset,
|
|
477
|
+
out_cache_loc,
|
|
478
|
+
req_to_token.shape[1],
|
|
479
|
+
next_power_of_2(batch_size),
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
return out_cache_loc
|
|
483
|
+
|
|
484
|
+
elif _is_npu:
|
|
485
|
+
import sgl_kernel_npu # noqa: F401
|
|
486
|
+
|
|
487
|
+
out_cache_loc = torch.empty(
|
|
488
|
+
(batch_size * draft_token_num,),
|
|
489
|
+
dtype=torch.int32,
|
|
490
|
+
device=device,
|
|
491
|
+
)
|
|
492
|
+
torch.ops.npu.cache_loc_update(
|
|
493
|
+
req_pool_indices,
|
|
494
|
+
req_to_token,
|
|
495
|
+
start_offset,
|
|
496
|
+
end_offset,
|
|
497
|
+
out_cache_loc,
|
|
498
|
+
)
|
|
499
|
+
out_cache_loc = out_cache_loc.to(dtype=torch.int64)
|
|
500
|
+
|
|
501
|
+
return out_cache_loc
|