sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +13 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +12 -16
- sglang/srt/disaggregation/prefill.py +17 -13
- sglang/srt/disaggregation/utils.py +46 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +22 -28
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +21 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +19 -9
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +207 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +91 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/io_struct.py +9 -12
- sglang/srt/managers/schedule_batch.py +40 -31
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +147 -62
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +76 -45
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +22 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +108 -26
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +36 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/utils.py +177 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def
-if model_runner.server_args
-Scheduler.
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         )

sglang/srt/_custom_ops.py
CHANGED
@@ -4,7 +4,7 @@ from typing import List, Tuple

 import torch

-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

 logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
@@ -25,7 +25,7 @@ if not is_hpu():
        logger.warning("Failed to import from custom_ar with %r", e)


-if not is_hip():
+if not is_hip() and not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
     else:
sglang/srt/code_completion_parser.py
CHANGED
@@ -15,12 +15,10 @@


 import dataclasses
-import json
 import logging
-import os
 from enum import auto

-from sglang.srt.
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest

 logger = logging.getLogger(__name__)
 completion_template_name = None
@@ -57,46 +55,6 @@ class CompletionTemplate:
 completion_templates: dict[str, CompletionTemplate] = {}


-def load_completion_template_for_openai_api(completion_template_arg):
-    global completion_template_name
-
-    logger.info(
-        f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
-    )
-
-    if not completion_template_exists(completion_template_arg):
-        if not os.path.exists(completion_template_arg):
-            raise RuntimeError(
-                f"Completion template {completion_template_arg} is not a built-in template name "
-                "or a valid completion template file path."
-            )
-
-        assert completion_template_arg.endswith(
-            ".json"
-        ), "unrecognized format of completion template file"
-        with open(completion_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                fim_position = FimPosition[template["fim_position"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown fim position: {template['fim_position']}"
-                ) from None
-            register_completion_template(
-                CompletionTemplate(
-                    name=template["name"],
-                    fim_begin_token=template["fim_begin_token"],
-                    fim_middle_token=template["fim_middle_token"],
-                    fim_end_token=template["fim_end_token"],
-                    fim_position=fim_position,
-                ),
-                override=True,
-            )
-            completion_template_name = template["name"]
-    else:
-        completion_template_name = completion_template_arg
-
-
 def register_completion_template(template: CompletionTemplate, override: bool = False):
     """Register a new completion template."""
     if not override:
@@ -116,7 +74,7 @@ def is_completion_template_defined() -> bool:
     return completion_template_name is not None


-def generate_completion_prompt_from_request(request:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
     global completion_template_name
     if request.suffix == "":
         return request.prompt
sglang/srt/constants.py
ADDED
sglang/srt/conversation.py
CHANGED
@@ -11,7 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Conversation chat templates.
+"""Conversation chat templates.
+
+This module provides conversation template definitions, data structures, and utilities
+for managing chat templates across different model types in SGLang.
+
+Key components:
+- Conversation class: Defines the structure and behavior of chat templates
+- SeparatorStyle enum: Different conversation formatting styles
+- Template registry: Functions to register and retrieve templates by name or model path
+- Built-in templates: Pre-defined templates for popular models
+"""

 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union

-from sglang.srt.
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 from sglang.srt.utils import read_system_prompt_from_file


@@ -618,7 +628,7 @@ def generate_chat_conv(


 # llama2 template
-# reference: https://
+# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
 register_conv_template(
     Conversation(
sglang/srt/custom_op.py
CHANGED
@@ -1,9 +1,11 @@
 from torch import nn

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()


 class CustomOp(nn.Module):
@@ -75,5 +77,7 @@ class CustomOp(nn.Module):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
+        elif _is_cpu and _is_cpu_amx_available:
+            return self.forward_cpu
         else:
             return self.forward_native
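Note on the dispatch change above: forward_cpu is selected only when the process runs on CPU and AMX support is detected; CPU builds without AMX keep falling back to forward_native. A standalone toy mirroring that selection order (hypothetical helper, not the package's CustomOp class):

    def pick_forward_method(is_cuda: bool, is_hip: bool, is_cpu: bool, amx_available: bool) -> str:
        # Same priority order as the hunk above: CUDA, HIP, CPU-with-AMX, then native fallback.
        if is_cuda:
            return "forward_cuda"
        if is_hip:
            return "forward_hip"
        if is_cpu and amx_available:
            return "forward_cpu"
        return "forward_native"

    assert pick_forward_method(False, False, True, False) == "forward_native"
    assert pick_forward_method(False, False, True, True) == "forward_cpu"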
sglang/srt/disaggregation/decode.py
CHANGED
@@ -21,16 +21,15 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
-import os
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch.distributed import ProcessGroup

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
@@ -46,14 +45,12 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import
-    KVCache,
-    ReqToTokenPool,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import require_mlp_sync

 logger = logging.getLogger(__name__)

@@ -90,7 +87,7 @@ class DecodeReqToTokenPool:
         self.max_context_len = max_context_len
         self.device = device
         self.pre_alloc_size = pre_alloc_size
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size + pre_alloc_size, max_context_len),
                 dtype=torch.int32,
@@ -139,7 +136,7 @@ class DecodePreallocQueue:
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator:
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
         metadata_buffers: MetadataBuffers,
@@ -540,6 +537,7 @@ class DecodeTransferQueue:
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm

     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -581,6 +579,7 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
+                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
@@ -588,7 +587,8 @@
             ) = self.metadata_buffers.get_buf(idx)

             decode_req.req.output_ids.append(output_id[0].item())
-
+            if not self.spec_algorithm.is_none():
+                decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
                     output_token_logprobs_val[0].item()
@@ -645,10 +645,7 @@ class SchedulerDisaggregationDecodeMixin:
         batch = self.get_next_disagg_decode_batch_to_run()
         self.cur_batch = batch

-
-            self.server_args.enable_dp_attention
-            or self.server_args.enable_sp_layernorm
-        )
+        prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

         if batch:
             # Generate fake extend output.
@@ -657,14 +654,14 @@ class SchedulerDisaggregationDecodeMixin:
             self.stream_output(
                 batch.reqs, any(req.return_logprob for req in batch.reqs)
             )
-            if
+            if prepare_mlp_sync_flag:
                 self._prepare_idle_batch_and_run(None)
         else:
-            if
-                self.
+            if prepare_mlp_sync_flag:
+                self.prepare_mlp_sync_batch(batch)
             result = self.run_batch(batch)
             self.process_batch_result(batch, result)
-        elif
+        elif prepare_mlp_sync_flag:
             batch, _ = self._prepare_idle_batch_and_run(None)

         if batch is None and (
@@ -695,10 +692,7 @@ class SchedulerDisaggregationDecodeMixin:
         self.cur_batch = batch
         last_batch_in_queue = False

-
-            self.server_args.enable_dp_attention
-            or self.server_args.enable_sp_layernorm
-        )
+        prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

         if batch:
             # Generate fake extend output.
@@ -707,7 +701,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.stream_output(
                 batch.reqs, any(req.return_logprob for req in batch.reqs)
             )
-            if
+            if prepare_mlp_sync_flag:
                 batch_, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
@@ -715,8 +709,8 @@
                 result_queue.append((batch_.copy(), result))
                 last_batch_in_queue = True
         else:
-            if
-                self.
+            if prepare_mlp_sync_flag:
+                self.prepare_mlp_sync_batch(batch)
             result = self.run_batch(batch)
             result_queue.append((batch.copy(), result))

@@ -731,7 +725,7 @@
                 self.set_next_batch_sampling_info_done(tmp_batch)
             last_batch_in_queue = True

-        elif
+        elif prepare_mlp_sync_flag:
             batch, result = self._prepare_idle_batch_and_run(
                 None, delay_process=True
             )
@@ -761,8 +755,8 @@
         self.last_batch = batch
         self.last_batch_in_queue = last_batch_in_queue

-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.
+    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+        batch, _ = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
sglang/srt/disaggregation/decode_schedule_batch_mixin.py
CHANGED
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
         )
         topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

+        hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+        hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
         # local import to avoid circular import
         from sglang.srt.speculative.eagle_utils import EagleDraftInput

         spec_info = EagleDraftInput(
             topk_p=topk_p,
             topk_index=topk_index,
-            hidden_states=
-                (b, model_config.hidden_size), device=self.device
-            ),
+            hidden_states=hidden_states,
             verified_id=self.output_ids,
         )
         spec_info.prepare_for_extend(self)
sglang/srt/disaggregation/mini_lb.py
CHANGED
@@ -18,6 +18,10 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import PDRegistryRequest

+AIOHTTP_STREAM_READ_CHUNK_SIZE = (
+    1024 * 64
+)  # 64KB, to prevent aiohttp's "Chunk too big" error
+

 def setup_logger():
     logger = logging.getLogger("pdlb")
@@ -154,7 +158,9 @@ class MiniLoadBalancer:
                 else:
                     yield chunk
             else:
-                async for chunk in decode_response.content
+                async for chunk in decode_response.content.iter_chunked(
+                    AIOHTTP_STREAM_READ_CHUNK_SIZE
+                ):
                     yield chunk

         return StreamingResponse(
@@ -212,15 +218,39 @@ async def get_server_info():
     )
     prefill_infos = []
     decode_infos = []
+    all_internal_states = []
+
     async with aiohttp.ClientSession() as session:
         for server in chain(prefill_servers):
             server_info = await session.get(f"{server}/get_server_info")
             prefill_infos.append(await server_info.json())
         for server in chain(decode_servers):
             server_info = await session.get(f"{server}/get_server_info")
-
-
-
+            info_json = await server_info.json()
+            decode_infos.append(info_json)
+            # Extract internal_states from decode servers
+            if "internal_states" in info_json:
+                all_internal_states.extend(info_json["internal_states"])
+
+    # Return format expected by bench_one_batch_server.py
+    if all_internal_states:
+        return {
+            "internal_states": all_internal_states,
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }
+    else:
+        # Fallback with dummy data if no internal states found
+        return {
+            "internal_states": [
+                {
+                    "last_gen_throughput": 0.0,
+                    "avg_spec_accept_length": None,
+                }
+            ],
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }


 @app.get("/get_model_info")
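Note on the iter_chunked change above: iter_chunked is a standard aiohttp StreamReader method, and the new constant caps each read at 64 KB. A minimal sketch of the same relay pattern in isolation (hypothetical URL and payload, not part of the package):

    import aiohttp

    AIOHTTP_STREAM_READ_CHUNK_SIZE = 1024 * 64  # 64 KB per read, as in the diff above

    async def relay_stream(url: str, payload: dict):
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as resp:
                # Each yielded chunk is at most AIOHTTP_STREAM_READ_CHUNK_SIZE bytes,
                # which keeps proxied streaming responses within aiohttp's chunk limits.
                async for chunk in resp.content.iter_chunked(AIOHTTP_STREAM_READ_CHUNK_SIZE):
                    yield chunk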
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
-    get_free_port,
-    get_int_env_var,
-    get_ip,
-    get_local_ip_by_remote,
-)
+from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto

 logger = logging.getLogger(__name__)

@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
         is_mla_backend: Optional[bool] = False,
     ):
         self.kv_args = args
+        self.local_ip = get_local_ip_auto()
         self.engine = MooncakeTransferEngine(
-            hostname=
+            hostname=self.local_ip,
             gpu_id=self.kv_args.gpu_id,
             ib_device=self.kv_args.ib_device,
         )
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):

     def start_prefill_thread(self):
         self.rank_port = get_free_port()
-        self.server_socket.bind(f"tcp://{
+        self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")

         def bootstrap_thread():
             """This thread recvs pre-alloc notification from the decode engine"""
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):

     def start_decode_thread(self):
         self.rank_port = get_free_port()
-        self.server_socket.bind(f"tcp://{
+        self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")

         def decode_thread():
             while True:
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
             "role": "Prefill",
             "tp_size": self.tp_size,
             "dp_size": self.dp_size,
-            "rank_ip":
+            "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
             "engine_rank": self.kv_args.engine_rank,
         }
@@ -746,12 +742,12 @@ class MooncakeKVSender(BaseKVSender):
             self.kv_mgr.request_status.pop(self.bootstrap_room)

     def failure_exception(self):
-        self.clear()
-
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
             self.conclude_state = KVPoll.Failed

+        self.clear()
+
         with self.kv_mgr.failure_lock:
             failure_reason = self.kv_mgr.failure_records.pop(
                 self.bootstrap_room, "Failed due to an unknown reason from another rank"
@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             sock.send_multipart(
                 [
                     "None".encode("ascii"),
-
+                    self.kv_mgr.local_ip.encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
                     packed_kv_data_ptrs,
@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             sock.send_multipart(
                 [
                     str(self.bootstrap_room).encode("ascii"),
-
+                    self.kv_mgr.local_ip.encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
                     kv_indices.tobytes() if not is_dummy else b"",
@@ -1007,12 +1003,12 @@
             self.kv_mgr.request_status.pop(self.bootstrap_room)

     def failure_exception(self):
-        self.clear()
-
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
             self.conclude_state = KVPoll.Failed

+        self.clear()
+
         with self.kv_mgr.failure_lock:
             failure_reason = self.kv_mgr.failure_records.pop(
                 self.bootstrap_room, "Failed due to an unknown reason from another rank"
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -25,7 +25,6 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional

-import numpy as np
 import torch

 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -45,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import require_mlp_sync

 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -274,12 +274,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
         batch = self.get_new_batch_prefill()

-
-
-            self.server_args.enable_dp_attention
-            or self.server_args.enable_sp_layernorm
-        ):
-            batch, _ = self.prepare_dp_attn_batch(batch)
+        if require_mlp_sync(self.server_args):
+            batch, _ = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch

         if batch:
@@ -312,12 +308,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
         batch = self.get_new_batch_prefill()

-
-
-            self.server_args.enable_dp_attention
-            or self.server_args.enable_sp_layernorm
-        ):
-            batch, _ = self.prepare_dp_attn_batch(batch)
+        if require_mlp_sync(self.server_args):
+            batch, _ = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
@@ -393,6 +385,8 @@ class SchedulerDisaggregationPrefillMixin:
             logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
+
+        hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
@@ -402,6 +396,16 @@
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 self.disagg_prefill_inflight_queue.append(req)
+                if logits_output.hidden_states is not None:
+                    last_hidden_index = (
+                        hidden_state_offset + extend_input_len_per_req[i] - 1
+                    )
+                    req.hidden_states_tensor = (
+                        logits_output.hidden_states[last_hidden_index].cpu().clone()
+                    )
+                    hidden_state_offset += extend_input_len_per_req[i]
+                else:
+                    req.hidden_states_tensor = None
                 if req.return_logprob:
                     assert extend_logprob_start_len_per_req is not None
                     assert extend_input_len_per_req is not None
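Note on the last hunk above: logits_output.hidden_states is packed along the token dimension for the whole extend batch, so the last hidden state of request i sits at the running offset plus extend_input_len_per_req[i] - 1. A small worked example of that index arithmetic (hypothetical lengths and hidden size):

    import torch

    extend_input_len_per_req = [5, 3, 7]  # hypothetical prompt lengths in one extend batch
    hidden = torch.randn(sum(extend_input_len_per_req), 8)  # packed (num_tokens, hidden_size)

    offset = 0
    last_indices = []
    for length in extend_input_len_per_req:
        last_indices.append(offset + length - 1)  # last token of this request in the packed tensor
        offset += length

    assert last_indices == [4, 7, 14]
    per_req_last_hidden = [hidden[i] for i in last_indices]  # one (hidden_size,) vector per request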
|