sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def
-    if model_runner.server_args
-        Scheduler.
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         )

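The helper added above only does the extra batch preparation when require_mlp_sync says so. A minimal, self-contained sketch of that gating pattern; ServerArgs, require_mlp_sync, and the prepared flag below are simplified stand-ins, not the real sglang objects:

from dataclasses import dataclass


@dataclass
class ServerArgs:
    enable_dp_attention: bool = False
    enable_sp_layernorm: bool = False


def require_mlp_sync(server_args: ServerArgs) -> bool:
    # Mirrors the condition the diff replaces elsewhere: sync is needed when
    # DP attention or SP layernorm is enabled (simplified assumption).
    return server_args.enable_dp_attention or server_args.enable_sp_layernorm


def maybe_prepare_mlp_sync_batch(batch: dict, server_args: ServerArgs) -> None:
    # Stand-in for Scheduler.prepare_mlp_sync_batch_raw(...): only do the extra
    # work when the server configuration requires it.
    if require_mlp_sync(server_args):
        batch["mlp_sync_prepared"] = True


batch = {}
maybe_prepare_mlp_sync_batch(batch, ServerArgs(enable_dp_attention=True))
print(batch)  # {'mlp_sync_prepared': True}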
sglang/srt/_custom_ops.py
CHANGED
@@ -4,7 +4,7 @@ from typing import List, Tuple

 import torch

-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

 logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
@@ -25,7 +25,7 @@ if not is_hpu():
         logger.warning("Failed to import from custom_ar with %r", e)


-if not is_hip():
+if not is_hip() and not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
     else:
sglang/srt/code_completion_parser.py
CHANGED
@@ -15,12 +15,10 @@


 import dataclasses
-import json
 import logging
-import os
 from enum import auto

-from sglang.srt.
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest

 logger = logging.getLogger(__name__)
 completion_template_name = None
@@ -57,46 +55,6 @@ class CompletionTemplate:
 completion_templates: dict[str, CompletionTemplate] = {}


-def load_completion_template_for_openai_api(completion_template_arg):
-    global completion_template_name
-
-    logger.info(
-        f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
-    )
-
-    if not completion_template_exists(completion_template_arg):
-        if not os.path.exists(completion_template_arg):
-            raise RuntimeError(
-                f"Completion template {completion_template_arg} is not a built-in template name "
-                "or a valid completion template file path."
-            )
-
-        assert completion_template_arg.endswith(
-            ".json"
-        ), "unrecognized format of completion template file"
-        with open(completion_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                fim_position = FimPosition[template["fim_position"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown fim position: {template['fim_position']}"
-                ) from None
-            register_completion_template(
-                CompletionTemplate(
-                    name=template["name"],
-                    fim_begin_token=template["fim_begin_token"],
-                    fim_middle_token=template["fim_middle_token"],
-                    fim_end_token=template["fim_end_token"],
-                    fim_position=fim_position,
-                ),
-                override=True,
-            )
-            completion_template_name = template["name"]
-    else:
-        completion_template_name = completion_template_arg
-
-
 def register_completion_template(template: CompletionTemplate, override: bool = False):
     """Register a new completion template."""
     if not override:
@@ -116,7 +74,7 @@ def is_completion_template_defined() -> bool:
     return completion_template_name is not None


-def generate_completion_prompt_from_request(request:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
     global completion_template_name
     if request.suffix == "":
         return request.prompt
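For reference, the loader removed above read a JSON completion-template file with the keys shown in the deleted lines (name, fim_begin_token, fim_middle_token, fim_end_token, fim_position). A hypothetical template, with placeholder token strings and a guessed fim_position value, not the tokens of any particular model:

import json

# Hypothetical FIM completion template; keys match what the removed loader read,
# but every value here is a made-up placeholder.
template = {
    "name": "my-fim-template",
    "fim_begin_token": "<FIM_BEGIN>",
    "fim_middle_token": "<FIM_MIDDLE>",
    "fim_end_token": "<FIM_END>",
    "fim_position": "middle",
}
print(json.dumps(template, indent=2))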
sglang/srt/constants.py
ADDED
sglang/srt/conversation.py
CHANGED
@@ -11,7 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Conversation chat templates.
+"""Conversation chat templates.
+
+This module provides conversation template definitions, data structures, and utilities
+for managing chat templates across different model types in SGLang.
+
+Key components:
+- Conversation class: Defines the structure and behavior of chat templates
+- SeparatorStyle enum: Different conversation formatting styles
+- Template registry: Functions to register and retrieve templates by name or model path
+- Built-in templates: Pre-defined templates for popular models
+"""

 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union

-from sglang.srt.
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 from sglang.srt.utils import read_system_prompt_from_file


@@ -618,7 +628,7 @@ def generate_chat_conv(


 # llama2 template
-# reference: https://
+# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
 register_conv_template(
     Conversation(
@@ -813,6 +823,7 @@ register_conv_template(
         sep_style=SeparatorStyle.GEMMA3,
         stop_str=["<end_of_turn>"],
         image_token="<start_of_image>",
+        audio_token="<start_of_audio>",
     )
 )

sglang/srt/custom_op.py
CHANGED
@@ -1,9 +1,12 @@
 from torch import nn

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_npu = is_npu()


 class CustomOp(nn.Module):
@@ -58,6 +61,9 @@ class CustomOp(nn.Module):
     def forward_cuda(self, *args, **kwargs):
         raise NotImplementedError

+    def forward_npu(self, *args, **kwargs):
+        raise NotImplementedError
+
     def forward_hip(self, *args, **kwargs):
         return self.forward_cuda(*args, **kwargs)

@@ -75,5 +81,9 @@ class CustomOp(nn.Module):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
+        elif _is_cpu and _is_cpu_amx_available:
+            return self.forward_cpu
+        elif _is_npu:
+            return self.forward_npu
         else:
             return self.forward_native
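The dispatch that the last hunk extends picks one backend-specific forward method based on platform probes. A condensed sketch of that pattern; the hard-coded flags stand in for the is_cuda()/is_hip()/is_cpu()/is_npu() probes, and CustomOpSketch is not the real class:

from torch import nn

# Stand-ins for the platform probes evaluated once at import time.
_is_cuda, _is_hip, _is_npu = False, False, False
_is_cpu, _is_cpu_amx_available = True, True


class CustomOpSketch(nn.Module):
    def forward_native(self, x):
        return x

    # In this sketch all backends share one implementation; the real class
    # overrides them per device.
    forward_cuda = forward_hip = forward_cpu = forward_npu = forward_native

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        elif _is_cpu and _is_cpu_amx_available:
            return self.forward_cpu
        elif _is_npu:
            return self.forward_npu
        return self.forward_native


op = CustomOpSketch()
print(op.dispatch_forward().__name__)  # forward_native (shared impl in this sketch)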
sglang/srt/disaggregation/decode.py
CHANGED
@@ -21,16 +21,15 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
-import os
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch.distributed import ProcessGroup

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
@@ -46,14 +45,12 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import
-    KVCache,
-    ReqToTokenPool,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import require_mlp_sync

 logger = logging.getLogger(__name__)

@@ -90,7 +87,7 @@ class DecodeReqToTokenPool:
         self.max_context_len = max_context_len
         self.device = device
         self.pre_alloc_size = pre_alloc_size
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size + pre_alloc_size, max_context_len),
                 dtype=torch.int32,
@@ -139,7 +136,7 @@ class DecodePreallocQueue:
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
         metadata_buffers: MetadataBuffers,
@@ -540,6 +537,7 @@ class DecodeTransferQueue:
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm

     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -585,10 +583,12 @@ class DecodeTransferQueue:
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
                 output_top_logprobs_idx,
+                output_hidden_states,
             ) = self.metadata_buffers.get_buf(idx)

             decode_req.req.output_ids.append(output_id[0].item())
-
+            if not self.spec_algorithm.is_none():
+                decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
                     output_token_logprobs_val[0].item()
@@ -645,10 +645,7 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch

-
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

             if batch:
                 # Generate fake extend output.
@@ -657,14 +654,14 @@ class SchedulerDisaggregationDecodeMixin:
                 self.stream_output(
                     batch.reqs, any(req.return_logprob for req in batch.reqs)
                 )
-                if
+                if prepare_mlp_sync_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
-                if
-                    self.
+                if prepare_mlp_sync_flag:
+                    self.prepare_mlp_sync_batch(batch)
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
-            elif
+            elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
@@ -695,10 +692,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.cur_batch = batch
             last_batch_in_queue = False

-
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

             if batch:
                 # Generate fake extend output.
@@ -707,7 +701,7 @@ class SchedulerDisaggregationDecodeMixin:
                 self.stream_output(
                     batch.reqs, any(req.return_logprob for req in batch.reqs)
                 )
-                if
+                if prepare_mlp_sync_flag:
                     batch_, result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
                     )
@@ -715,8 +709,8 @@ class SchedulerDisaggregationDecodeMixin:
                     result_queue.append((batch_.copy(), result))
                     last_batch_in_queue = True
             else:
-                if
-                    self.
+                if prepare_mlp_sync_flag:
+                    self.prepare_mlp_sync_batch(batch)
                 result = self.run_batch(batch)
                 result_queue.append((batch.copy(), result))

@@ -731,7 +725,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.set_next_batch_sampling_info_done(tmp_batch)
                 last_batch_in_queue = True

-            elif
+            elif prepare_mlp_sync_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
@@ -761,8 +755,8 @@ class SchedulerDisaggregationDecodeMixin:
         self.last_batch = batch
         self.last_batch_in_queue = last_batch_in_queue

-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.
+    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+        batch, _ = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
CHANGED
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
         )
         topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

+        hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+        hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
         # local import to avoid circular import
         from sglang.srt.speculative.eagle_utils import EagleDraftInput

         spec_info = EagleDraftInput(
             topk_p=topk_p,
             topk_index=topk_index,
-            hidden_states=
-                (b, model_config.hidden_size), device=self.device
-            ),
+            hidden_states=hidden_states,
             verified_id=self.output_ids,
         )
         spec_info.prepare_for_extend(self)
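The hunk above replaces a placeholder buffer with the per-request hidden states carried on each request, stacked into a (batch, hidden_size) tensor for the Eagle draft input. A minimal torch illustration of that stacking step (sizes made up):

import torch

hidden_size = 8
# One hidden-state tensor per request, as carried on req.hidden_states_tensor.
per_req_hidden_states = [torch.randn(hidden_size) for _ in range(4)]
hidden_states = torch.stack(per_req_hidden_states, dim=0)
print(hidden_states.shape)  # torch.Size([4, 8])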
sglang/srt/disaggregation/mini_lb.py
CHANGED
@@ -18,6 +18,10 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import PDRegistryRequest

+AIOHTTP_STREAM_READ_CHUNK_SIZE = (
+    1024 * 64
+)  # 64KB, to prevent aiohttp's "Chunk too big" error
+

 def setup_logger():
     logger = logging.getLogger("pdlb")
@@ -154,7 +158,9 @@ class MiniLoadBalancer:
                     else:
                         yield chunk
                 else:
-                    async for chunk in decode_response.content
+                    async for chunk in decode_response.content.iter_chunked(
+                        AIOHTTP_STREAM_READ_CHUNK_SIZE
+                    ):
                         yield chunk

         return StreamingResponse(
@@ -212,15 +218,39 @@ async def get_server_info():
     )
     prefill_infos = []
     decode_infos = []
+    all_internal_states = []
+
     async with aiohttp.ClientSession() as session:
         for server in chain(prefill_servers):
             server_info = await session.get(f"{server}/get_server_info")
             prefill_infos.append(await server_info.json())
         for server in chain(decode_servers):
             server_info = await session.get(f"{server}/get_server_info")
-
-
-
+            info_json = await server_info.json()
+            decode_infos.append(info_json)
+            # Extract internal_states from decode servers
+            if "internal_states" in info_json:
+                all_internal_states.extend(info_json["internal_states"])
+
+    # Return format expected by bench_one_batch_server.py
+    if all_internal_states:
+        return {
+            "internal_states": all_internal_states,
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }
+    else:
+        # Fallback with dummy data if no internal states found
+        return {
+            "internal_states": [
+                {
+                    "last_gen_throughput": 0.0,
+                    "avg_spec_accept_length": None,
+                }
+            ],
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }


 @app.get("/get_model_info")
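The streaming change above reads the decode server's response body in bounded chunks; per the diff's own comment, the 64KB cap is there to avoid aiohttp's "Chunk too big" error. A minimal aiohttp sketch of that read pattern (the URL is a placeholder):

import aiohttp

CHUNK_SIZE = 1024 * 64  # 64KB read size, mirroring AIOHTTP_STREAM_READ_CHUNK_SIZE


async def stream_body(url: str):
    # Yield the response body in fixed-size pieces instead of whatever chunk
    # sizes the peer happens to send.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                yield chunk


# Example driver (placeholder endpoint):
# import asyncio
# async def main():
#     async for chunk in stream_body("http://localhost:8000/generate"):
#         print(len(chunk))
# asyncio.run(main())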