sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
 
 from sglang.srt.connector import ConnectorType
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -121,6 +121,17 @@ NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"
 
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
 
+MOE_RUNNER_BACKEND_CHOICES = [
+    "auto",
+    "deep_gemm",
+    "triton",
+    "triton_kernel",
+    "flashinfer_trtllm",
+    "flashinfer_cutlass",
+    "flashinfer_mxfp4",
+    "flashinfer_cutedsl",
+]
+
 
 # Allow external code to add more choices
 def add_load_format_choices(choices):
@@ -143,6 +154,10 @@ def add_grammar_backend_choices(choices):
     GRAMMAR_BACKEND_CHOICES.extend(choices)
 
 
+def add_moe_runner_backend_choices(choices):
+    MOE_RUNNER_BACKEND_CHOICES.extend(choices)
+
+
 def add_deterministic_attention_backend_choices(choices):
     DETERMINISTIC_ATTENTION_BACKEND_CHOICES.extend(choices)
 
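Note: the new MOE_RUNNER_BACKEND_CHOICES list and add_moe_runner_backend_choices() mirror the existing *_CHOICES registries, so out-of-tree code can widen the accepted --moe-runner-backend values before the CLI parser is built. A minimal sketch (the "my_custom_backend" name is hypothetical; registering it only extends the accepted choices, an actual runner implementation still has to be provided elsewhere):

```python
from sglang.srt import server_args

# Hypothetical out-of-tree backend name; must be registered before the
# --moe-runner-backend argument choices are constructed.
server_args.add_moe_runner_backend_choices(["my_custom_backend"])
```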
@@ -162,6 +177,7 @@ class ServerArgs:
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
+    modelopt_quant: Optional[Union[str, Dict]] = None
     context_length: Optional[int] = None
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
@@ -204,7 +220,7 @@ class ServerArgs:
     device: Optional[str] = None
     tp_size: int = 1
     pp_size: int = 1
-
+    pp_max_micro_batch_size: Optional[int] = None
     stream_interval: int = 1
     stream_output: bool = False
     random_seed: Optional[int] = None
@@ -251,6 +267,7 @@ class ServerArgs:
     reasoning_parser: Optional[str] = None
     tool_call_parser: Optional[str] = None
     tool_server: Optional[str] = None
+    sampling_defaults: str = "model"
 
     # Data parallelism
     dp_size: int = 1
@@ -313,14 +330,7 @@ class ServerArgs:
     # Expert parallelism
     ep_size: int = 1
     moe_a2a_backend: Literal["none", "deepep"] = "none"
-    moe_runner_backend:
-        "auto",
-        "triton",
-        "triton_kernel",
-        "flashinfer_trtllm",
-        "flashinfer_cutlass",
-        "flashinfer_mxfp4",
-    ] = "auto"
+    moe_runner_backend: str = "auto"
     flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
@@ -372,6 +382,12 @@ class ServerArgs:
     offload_prefetch_step: int = 1
     offload_mode: str = "cpu"
 
+    # Scoring configuration
+    # Delimiter token ID used to combine Query and Items into a single sequence for multi-item scoring.
+    # Format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+    # This enables efficient batch processing of multiple items against a single query.
+    multi_item_scoring_delimiter: Optional[Union[int]] = None
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     cuda_graph_max_bs: Optional[int] = None
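Note: the comment on the new field describes how a scoring request is flattened into a single token sequence. A rough illustration of that layout, using made-up token IDs and delimiter value:

```python
# Illustrative token IDs only; the delimiter is whatever token ID is passed
# via --multi-item-scoring-delimiter.
delimiter = 128009
query_tokens = [101, 102, 103]
item_tokens = [[201, 202], [301], [401, 402, 403]]

# Query<delimiter>Item1<delimiter>Item2<delimiter>...
sequence = list(query_tokens)
for item in item_tokens:
    sequence.append(delimiter)
    sequence.extend(item)

print(sequence)
# [101, 102, 103, 128009, 201, 202, 128009, 301, 128009, 401, 402, 403]
```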
@@ -454,6 +470,19 @@ class ServerArgs:
     enable_pdmux: bool = False
     sm_group_num: int = 3
 
+    def get_attention_backends(server_args):
+        prefill_attention_backend_str = (
+            server_args.prefill_attention_backend
+            if server_args.prefill_attention_backend
+            else server_args.attention_backend
+        )
+        decode_attention_backend_str = (
+            server_args.decode_attention_backend
+            if server_args.decode_attention_backend
+            else server_args.attention_backend
+        )
+        return prefill_attention_backend_str, decode_attention_backend_str
+
     def __post_init__(self):
         """
         Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
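Note: the helper added above resolves per-phase backends with a simple fallback. A standalone sketch of the same logic, with arbitrary example values:

```python
from types import SimpleNamespace

def get_attention_backends(server_args):
    # Mirrors the new ServerArgs.get_attention_backends(): a phase-specific
    # backend wins, otherwise fall back to the generic --attention-backend value.
    prefill = (
        server_args.prefill_attention_backend
        if server_args.prefill_attention_backend
        else server_args.attention_backend
    )
    decode = (
        server_args.decode_attention_backend
        if server_args.decode_attention_backend
        else server_args.attention_backend
    )
    return prefill, decode

args = SimpleNamespace(
    attention_backend="fa3",
    prefill_attention_backend=None,
    decode_attention_backend="triton",
)
print(get_attention_backends(args))  # ('fa3', 'triton')
```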
@@ -527,7 +556,13 @@ class ServerArgs:
         self._handle_other_validations()
 
     def _handle_deprecated_args(self):
-
+        # handle deprecated tool call parsers
+        deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}
+        if self.tool_call_parser in deprecated_tool_call_parsers:
+            logger.warning(
+                f"The tool_call_parser '{self.tool_call_parser}' is deprecated. Please use '{deprecated_tool_call_parsers[self.tool_call_parser]}' instead."
+            )
+            self.tool_call_parser = deprecated_tool_call_parsers[self.tool_call_parser]
 
     def _handle_missing_default_values(self):
         if self.tokenizer_path is None:
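Note: the deprecation shim only renames two parser aliases. A self-contained sketch of the mapping it applies (other parser names pass through unchanged):

```python
DEPRECATED_TOOL_CALL_PARSERS = {"qwen25": "qwen", "glm45": "glm"}

def resolve_tool_call_parser(name):
    # Old aliases are rewritten (with a warning in the real code);
    # everything else is returned as-is.
    return DEPRECATED_TOOL_CALL_PARSERS.get(name, name)

assert resolve_tool_call_parser("qwen25") == "qwen"
assert resolve_tool_call_parser("glm45") == "glm"
assert resolve_tool_call_parser("mistral") == "mistral"
```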
@@ -732,20 +767,28 @@ class ServerArgs:
         hf_config = self.get_hf_config()
         model_arch = hf_config.architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
-            if
+            if (
+                self.attention_backend is None
+                and self.prefill_attention_backend is None
+                and self.decode_attention_backend is None
+            ):
                 if is_cuda() and is_sm100_supported():
                     self.attention_backend = "trtllm_mha"
                 elif is_cuda() and is_sm90_supported():
                     self.attention_backend = "fa3"
                 else:
                     self.attention_backend = "triton"
-
-
-
-            )
+
+            supported_backends = ["triton", "trtllm_mha", "fa3", "fa4"]
+            prefill_attn_backend, decode_attn_backend = self.get_attention_backends()
             assert (
-
-
+                prefill_attn_backend in supported_backends
+                and decode_attn_backend in supported_backends
+            ), (
+                f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got the following backends\n"
+                f"- Prefill: {prefill_attn_backend}\n"
+                f"- Decode: {decode_attn_backend}\n"
+            )
 
         if is_sm100_supported():
             if not self.enable_dp_attention:
@@ -820,9 +863,6 @@ class ServerArgs:
             self.page_size = 64
             logger.warning("Setting page size to 64 for DeepSeek NSA.")
 
-            self.mem_fraction_static = 0.8
-            logger.warning("Setting mem fraction static to 0.8 for DeepSeek NSA.")
-
             # For Hopper, we support both bf16 and fp8 kv cache; for Blackwell, we support fp8 only currently
             import torch
 
@@ -1455,6 +1495,14 @@ class ServerArgs:
             "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
             "default to 1.0, which may cause accuracy issues. ",
         )
+        parser.add_argument(
+            "--modelopt-quant",
+            type=str,
+            default=ServerArgs.modelopt_quant,
+            help="The ModelOpt quantization configuration. "
+            "Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
+            "This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
+        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
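Note: a hypothetical launch configuration exercising the new flag (the model path is a placeholder, and nvidia-modelopt must be installed as the help text says):

```python
# Hypothetical argv for `python -m sglang.launch_server`.
argv = [
    "--model-path", "org/some-fp16-checkpoint",
    "--modelopt-quant", "fp8",  # one of: fp8, int4_awq, w4a8_awq, nvfp4, nvfp4_awq
]
```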
@@ -1590,9 +1638,9 @@ class ServerArgs:
             help="The pipeline parallelism size.",
         )
         parser.add_argument(
-            "--max-micro-batch-size",
+            "--pp-max-micro-batch-size",
             type=int,
-            default=ServerArgs.
+            default=ServerArgs.pp_max_micro_batch_size,
             help="The maximum micro batch size in pipeline parallelism.",
         )
         parser.add_argument(
@@ -1857,6 +1905,16 @@ class ServerArgs:
             default=ServerArgs.tool_call_parser,
             help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
         )
+        parser.add_argument(
+            "--sampling-defaults",
+            type=str,
+            choices=["openai", "model"],
+            default=ServerArgs.sampling_defaults,
+            help="Where to get default sampling parameters. "
+            "'openai' uses SGLang/OpenAI defaults (temperature=1.0, top_p=1.0, etc.). "
+            "'model' uses the model's generation_config.json to get the recommended "
+            "sampling parameters if available. Default is 'model'.",
+        )
         parser.add_argument(
             "--tool-server",
             type=str,
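Note: a rough sketch of what the two --sampling-defaults modes mean for request parameters that the client leaves unset. The concrete numbers on the "model" side are illustrative, as if read from a generation_config.json:

```python
OPENAI_DEFAULTS = {"temperature": 1.0, "top_p": 1.0}
# Illustrative values, standing in for the model's generation_config.json.
MODEL_GENERATION_CONFIG = {"temperature": 0.6, "top_p": 0.95}

def resolve_sampling_defaults(sampling_defaults: str) -> dict:
    if sampling_defaults == "openai":
        return dict(OPENAI_DEFAULTS)
    # "model" (the new default): prefer the model's recommended values when present.
    return {**OPENAI_DEFAULTS, **MODEL_GENERATION_CONFIG}

print(resolve_sampling_defaults("model"))   # {'temperature': 0.6, 'top_p': 0.95}
print(resolve_sampling_defaults("openai"))  # {'temperature': 1.0, 'top_p': 1.0}
```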
@@ -2165,15 +2223,7 @@ class ServerArgs:
         parser.add_argument(
             "--moe-runner-backend",
             type=str,
-            choices=
-                "auto",
-                "triton",
-                "triton_kernel",
-                "flashinfer_trtllm",
-                "flashinfer_cutlass",
-                "flashinfer_mxfp4",
-                "flashinfer_cutedsl",
-            ],
+            choices=MOE_RUNNER_BACKEND_CHOICES,
             default=ServerArgs.moe_runner_backend,
             help="Choose the runner backend for MoE.",
         )
@@ -2287,7 +2337,13 @@ class ServerArgs:
             choices=["float32", "bfloat16"],
             help="The data type of the SSM states in mamba cache.",
         )
-
+        # Args for multi-item-scoring
+        parser.add_argument(
+            "--multi-item-scoring-delimiter",
+            type=int,
+            default=ServerArgs.multi_item_scoring_delimiter,
+            help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
+        )
         # Hierarchical cache
         parser.add_argument(
             "--enable-hierarchical-cache",
@@ -2957,6 +3013,17 @@ class ServerArgs:
             "lof",
         ], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
 
+        # Check multi-item scoring
+        if self.multi_item_scoring_delimiter is not None:
+            assert self.disable_radix_cache, (
+                "Multi-item scoring requires radix cache to be disabled. "
+                "Please set --disable-radix-cache when using --multi-item-scoring-delimiter."
+            )
+            assert self.chunked_prefill_size == -1, (
+                "Multi-item scoring requires chunked prefill to be disabled. "
+                "Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
+            )
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
 
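Note: per the new checks, the delimiter flag is only accepted with the radix cache disabled and chunked prefill turned off. A hypothetical argv that satisfies both constraints (model path and delimiter token ID are placeholders):

```python
# Hypothetical argv for `python -m sglang.launch_server`.
argv = [
    "--model-path", "org/some-model",
    "--multi-item-scoring-delimiter", "128009",
    "--disable-radix-cache",
    "--chunked-prefill-size", "-1",
]
```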
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -19,11 +19,11 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
-    ForwardBatchOutput,
     ForwardMode,
 )
 from sglang.srt.server_args import ServerArgs
@@ -429,7 +429,7 @@ class EAGLEWorker(TpModelWorker):
     def draft_model_runner(self):
         return self.model_runner
 
-    def forward_batch_generation(self, batch: ScheduleBatch) ->
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         """Run speculative decoding forward.
 
         NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -449,7 +449,7 @@ class EAGLEWorker(TpModelWorker):
             self.forward_draft_extend(
                 batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
             )
-            return
+            return GenerationBatchResult(
                 logits_output=logits_output,
                 next_token_ids=next_token_ids,
                 num_accepted_tokens=0,
@@ -472,7 +472,7 @@ class EAGLEWorker(TpModelWorker):
             # decode is not finished
             self.forward_draft_extend_after_decode(batch)
 
-        return
+        return GenerationBatchResult(
             logits_output=logits_output,
             next_token_ids=verify_output.verified_id,
             num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
@@ -513,12 +513,10 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-
-            model_worker_batch
-        )
+        batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
         logits_output, next_token_ids = (
-
-
+            batch_result.logits_output,
+            batch_result.next_token_ids,
         )
         return (
             logits_output,
@@ -822,12 +820,12 @@ class EAGLEWorker(TpModelWorker):
         ).cpu()
 
         # Forward
-
+        batch_result = self.target_worker.forward_batch_generation(
             model_worker_batch, is_verify=True
         )
         logits_output, can_run_cuda_graph = (
-
-
+            batch_result.logits_output,
+            batch_result.can_run_cuda_graph,
         )
 
         vocab_mask = None
@@ -868,7 +866,7 @@ class EAGLEWorker(TpModelWorker):
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
 
         # QQ: can be optimized
-        if self.target_worker.model_runner.
+        if self.target_worker.model_runner.hybrid_gdn_config is not None:
             # res.draft_input.accept_length is on GPU but may be empty for last verify?
             accepted_length = (
                 torch.tensor(
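Note: the hunks above (and the ngram worker below) replace tuple unpacking of the old ForwardBatchOutput with attribute access on GenerationBatchResult. A hedged sketch of the resulting calling pattern, using only the field names that appear in this diff:

```python
def run_target_forward(target_worker, model_worker_batch):
    # Attribute access on GenerationBatchResult replaces the old tuple unpacking.
    batch_result = target_worker.forward_batch_generation(model_worker_batch)
    return (
        batch_result.logits_output,
        batch_result.next_token_ids,
        batch_result.can_run_cuda_graph,
    )
```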
sglang/srt/speculative/ngram_worker.py
CHANGED
@@ -6,11 +6,12 @@ import torch
 from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.managers.tp_worker import TpModelWorker
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
-from sglang.srt.speculative.
+from sglang.srt.speculative.ngram_info import NgramVerifyInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
 logger = logging.getLogger(__name__)
@@ -207,18 +208,18 @@ class NGRAMWorker:
             batch_tokens.append(put_ids)
         self.ngram_cache.batch_put(batch_tokens)
 
-    def forward_batch_generation(self, batch: ScheduleBatch) ->
+    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
         self._prepare_for_speculative_decoding(batch)
         model_worker_batch = batch.get_model_worker_batch()
         num_accepted_tokens = 0
 
         if model_worker_batch.forward_mode.is_target_verify():
-
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch, is_verify=True
             )
             logits_output, can_run_cuda_graph = (
-
-
+                batch_result.logits_output,
+                batch_result.can_run_cuda_graph,
             )
             verify_input = model_worker_batch.spec_info
             logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
@@ -228,16 +229,16 @@ class NGRAMWorker:
             batch.forward_mode = ForwardMode.DECODE
 
         else:
-
+            batch_result = self.target_worker.forward_batch_generation(
                 model_worker_batch
             )
             logits_output, next_token_ids, can_run_cuda_graph = (
-
-
-
+                batch_result.logits_output,
+                batch_result.next_token_ids,
+                batch_result.can_run_cuda_graph,
             )
 
-        return
+        return GenerationBatchResult(
             logits_output=logits_output,
             next_token_ids=next_token_ids,
             num_accepted_tokens=num_accepted_tokens,
sglang/srt/two_batch_overlap.py
CHANGED
sglang/srt/utils/common.py
CHANGED
@@ -518,6 +518,24 @@ def make_layers(
     return modules, start_layer, end_layer
 
 
+def make_layers_non_pp(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> torch.nn.ModuleList:
+    from sglang.srt.offloader import get_offloader
+
+    layers = torch.nn.ModuleList(
+        get_offloader().wrap_modules(
+            (
+                layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
+                for idx in range(num_hidden_layers)
+            )
+        )
+    )
+    return layers
+
+
 cmo_stream = None
 
 
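Note: a hedged sketch of how a model definition might call the new helper. TinyLayer is a hypothetical stand-in for a real decoder layer, and the sketch assumes it runs during model construction, where the offloader returned by get_offloader() is already configured:

```python
import torch.nn as nn

from sglang.srt.utils.common import make_layers_non_pp

class TinyLayer(nn.Module):
    # Hypothetical stand-in for a real decoder layer.
    def __init__(self, idx: int, prefix: str = ""):
        super().__init__()
        self.idx = idx
        self.proj = nn.Linear(16, 16)

# Builds all layers on this rank (no pipeline-parallel start/end slicing),
# wrapped by the active offloader.
layers = make_layers_non_pp(
    num_hidden_layers=4,
    layer_fn=lambda idx, prefix: TinyLayer(idx, prefix),
    prefix="model.layers",
)
```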
sglang/srt/utils/hf_transformers_utils.py
CHANGED
@@ -45,6 +45,7 @@ from sglang.srt.configs import (
     KimiVLConfig,
     LongcatFlashConfig,
     MultiModalityConfig,
+    NemotronHConfig,
     Qwen3NextConfig,
     Step3VLConfig,
 )
@@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     FalconH1Config.model_type: FalconH1Config,
     DotsVLMConfig.model_type: DotsVLMConfig,
     DotsOCRConfig.model_type: DotsOCRConfig,
+    NemotronHConfig.model_type: NemotronHConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
sglang/test/longbench_v2/__init__.py
ADDED
@@ -0,0 +1 @@
+"""LongBench-v2 auxiliary utilities and validation scripts."""