sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
  configure_logger,
  get_bool_env_var,
  kill_process_tree,
+ require_mlp_sync,
+ require_mlp_tp_gather,
  set_gpu_proc_affinity,
  suppress_other_loggers,
  )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
  enable_custom_logit_processor=False,
  )
  batch.prepare_for_extend()
- _maybe_prepare_dp_attn_batch(batch, model_runner)
+ _maybe_prepare_mlp_sync_batch(batch, model_runner)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
  logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
  def decode(input_token_ids, batch, model_runner):
  batch.output_ids = input_token_ids
  batch.prepare_for_decode()
- _maybe_prepare_dp_attn_batch(batch, model_runner)
+ _maybe_prepare_mlp_sync_batch(batch, model_runner)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
  logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
  return next_token_ids, logits_output.next_token_logits


- def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
- if model_runner.server_args.enable_dp_attention:
- Scheduler.prepare_dp_attn_batch_raw(
+ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+ if require_mlp_sync(model_runner.server_args):
+ Scheduler.prepare_mlp_sync_batch_raw(
  batch,
  dp_size=model_runner.server_args.dp_size,
  attn_tp_size=1,
- moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
  tp_cpu_group=model_runner.tp_group.cpu_group,
  get_idle_batch=None,
  disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
  spec_algorithm=SpeculativeAlgorithm.NONE,
  speculative_num_draft_tokens=None,
+ require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
  )

sglang/srt/_custom_ops.py CHANGED
@@ -4,7 +4,7 @@ from typing import List, Tuple

  import torch

- from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
+ from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

  logger = logging.getLogger(__name__)
  use_vllm_custom_allreduce = get_bool_env_var(
@@ -25,7 +25,7 @@ if not is_hpu():
  logger.warning("Failed to import from custom_ar with %r", e)


- if not is_hip():
+ if not is_hip() and not is_npu():
  if use_vllm_custom_allreduce:
  custom_op = torch.ops._C_custom_ar
  else:
sglang/srt/code_completion_parser.py CHANGED
@@ -15,12 +15,10 @@


  import dataclasses
- import json
  import logging
- import os
  from enum import auto

- from sglang.srt.openai_api.protocol import ChatCompletionRequest
+ from sglang.srt.entrypoints.openai.protocol import CompletionRequest

  logger = logging.getLogger(__name__)
  completion_template_name = None
@@ -57,46 +55,6 @@ class CompletionTemplate:
  completion_templates: dict[str, CompletionTemplate] = {}


- def load_completion_template_for_openai_api(completion_template_arg):
- global completion_template_name
-
- logger.info(
- f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
- )
-
- if not completion_template_exists(completion_template_arg):
- if not os.path.exists(completion_template_arg):
- raise RuntimeError(
- f"Completion template {completion_template_arg} is not a built-in template name "
- "or a valid completion template file path."
- )
-
- assert completion_template_arg.endswith(
- ".json"
- ), "unrecognized format of completion template file"
- with open(completion_template_arg, "r") as filep:
- template = json.load(filep)
- try:
- fim_position = FimPosition[template["fim_position"]]
- except KeyError:
- raise ValueError(
- f"Unknown fim position: {template['fim_position']}"
- ) from None
- register_completion_template(
- CompletionTemplate(
- name=template["name"],
- fim_begin_token=template["fim_begin_token"],
- fim_middle_token=template["fim_middle_token"],
- fim_end_token=template["fim_end_token"],
- fim_position=fim_position,
- ),
- override=True,
- )
- completion_template_name = template["name"]
- else:
- completion_template_name = completion_template_arg
-
-
  def register_completion_template(template: CompletionTemplate, override: bool = False):
  """Register a new completion template."""
  if not override:
@@ -116,7 +74,7 @@ def is_completion_template_defined() -> bool:
  return completion_template_name is not None


- def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
+ def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
  global completion_template_name
  if request.suffix == "":
  return request.prompt
sglang/srt/configs/model_config.py CHANGED
@@ -565,6 +565,7 @@ multimodal_model_archs = [
  "CLIPModel",
  "DeepseekVL2ForCausalLM",
  "Gemma3ForConditionalGeneration",
+ "Gemma3nForConditionalGeneration",
  "Grok1VForCausalLM",
  "Grok1AForCausalLM",
  "LlavaLlamaForCausalLM",
sglang/srt/constants.py ADDED
@@ -0,0 +1,3 @@
+ # GPU Memory Types
+ GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
+ GPU_MEMORY_TYPE_WEIGHTS = "weights"
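These two constants are the tags the runtime passes to the torch memory saver when it sets up pools; the decode.py hunk further down wraps KV-cache allocation in memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE). Below is a minimal, self-contained sketch of that tagging pattern; DummyMemorySaverAdapter is an illustrative stand-in, not sglang's actual TorchMemorySaverAdapter API.

from contextlib import contextmanager

GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"

class DummyMemorySaverAdapter:  # hypothetical stand-in for the real adapter
    @contextmanager
    def region(self, tag: str):
        # A real adapter would route allocations made inside this block into a
        # pool that can later be paused/resumed per tag; here we only log it.
        print(f"entering region tagged {tag!r}")
        yield
        print(f"leaving region tagged {tag!r}")

adapter = DummyMemorySaverAdapter()
with adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
    kv_buffer = bytearray(1024)  # stands in for a KV-cache tensor allocation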
sglang/srt/conversation.py CHANGED
@@ -11,7 +11,17 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- """Conversation chat templates."""
+ """Conversation chat templates.
+
+ This module provides conversation template definitions, data structures, and utilities
+ for managing chat templates across different model types in SGLang.
+
+ Key components:
+ - Conversation class: Defines the structure and behavior of chat templates
+ - SeparatorStyle enum: Different conversation formatting styles
+ - Template registry: Functions to register and retrieve templates by name or model path
+ - Built-in templates: Pre-defined templates for popular models
+ """

  # Adapted from
  # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
  from enum import IntEnum, auto
  from typing import Callable, Dict, List, Optional, Tuple, Union

- from sglang.srt.openai_api.protocol import ChatCompletionRequest
+ from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
  from sglang.srt.utils import read_system_prompt_from_file


@@ -618,7 +628,7 @@ def generate_chat_conv(


  # llama2 template
- # reference: https://huggingface.co/blog/codellama#conversational-instructions
+ # reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
  # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
  register_conv_template(
  Conversation(
@@ -813,6 +823,7 @@ register_conv_template(
  sep_style=SeparatorStyle.GEMMA3,
  stop_str=["<end_of_turn>"],
  image_token="<start_of_image>",
+ audio_token="<start_of_audio>",
  )
  )

sglang/srt/custom_op.py CHANGED
@@ -1,9 +1,12 @@
  from torch import nn

- from sglang.srt.utils import is_cuda, is_hip
+ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

  _is_cuda = is_cuda()
  _is_hip = is_hip()
+ _is_cpu = is_cpu()
+ _is_cpu_amx_available = cpu_has_amx_support()
+ _is_npu = is_npu()


  class CustomOp(nn.Module):
@@ -58,6 +61,9 @@ class CustomOp(nn.Module):
  def forward_cuda(self, *args, **kwargs):
  raise NotImplementedError

+ def forward_npu(self, *args, **kwargs):
+ raise NotImplementedError
+
  def forward_hip(self, *args, **kwargs):
  return self.forward_cuda(*args, **kwargs)

@@ -75,5 +81,9 @@ class CustomOp(nn.Module):
  return self.forward_cuda
  elif _is_hip:
  return self.forward_hip
+ elif _is_cpu and _is_cpu_amx_available:
+ return self.forward_cpu
+ elif _is_npu:
+ return self.forward_npu
  else:
  return self.forward_native
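The hunks above extend CustomOp's one-time dispatch so a single module can carry CUDA, HIP, CPU (AMX), and NPU forward paths. The sketch below shows the dispatch-once pattern in isolation; the platform flags and the op body are made up for illustration and are not sglang's actual CustomOp.

import torch
from torch import nn

# Assumed flags: in sglang these come from sglang.srt.utils helpers.
_is_cuda, _is_hip, _is_npu = torch.cuda.is_available(), False, False
_is_cpu, _is_cpu_amx_available = not _is_cuda, False

class ToyScale(nn.Module):
    """Toy op showing the dispatch-once pattern; not sglang's CustomOp."""

    def __init__(self, scale: float):
        super().__init__()
        self.scale = scale
        # Resolve the platform-specific forward once, instead of branching per call.
        self._forward_method = self.dispatch_forward()

    def forward(self, x):
        return self._forward_method(x)

    def forward_native(self, x):
        return x * self.scale  # portable fallback

    def forward_cuda(self, x):
        return x * self.scale  # would call a fused CUDA kernel in practice

    def forward_cpu(self, x):
        return x * self.scale  # would call an AMX-optimized kernel in practice

    def forward_npu(self, x):
        raise NotImplementedError

    def dispatch_forward(self):
        if _is_cuda or _is_hip:
            return self.forward_cuda
        elif _is_cpu and _is_cpu_amx_available:
            return self.forward_cpu
        elif _is_npu:
            return self.forward_npu
        return self.forward_native

print(ToyScale(2.0)(torch.ones(3)))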
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -27,6 +27,8 @@ class KVArgs:
  decode_tp_size: int
  # for pp prefill
  prefill_pp_size: int
+ kv_head_num: int
+ page_size: int


  class KVPoll:
sglang/srt/disaggregation/decode.py CHANGED
@@ -21,16 +21,15 @@ Life cycle of a request in the decode server
  from __future__ import annotations

  import logging
- import os
  from collections import deque
  from dataclasses import dataclass
  from http import HTTPStatus
  from typing import TYPE_CHECKING, List, Optional, Tuple, Union

- import numpy as np
  import torch
  from torch.distributed import ProcessGroup

+ from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
  from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
  from sglang.srt.disaggregation.utils import (
  FAKE_BOOTSTRAP_HOST,
@@ -46,14 +45,12 @@ from sglang.srt.disaggregation.utils import (
  prepare_abort,
  )
  from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
- from sglang.srt.mem_cache.memory_pool import (
- KVCache,
- ReqToTokenPool,
- TokenToKVPoolAllocator,
- )
+ from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
  from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+ from sglang.srt.utils import require_mlp_sync

  logger = logging.getLogger(__name__)

@@ -90,7 +87,7 @@ class DecodeReqToTokenPool:
  self.max_context_len = max_context_len
  self.device = device
  self.pre_alloc_size = pre_alloc_size
- with memory_saver_adapter.region():
+ with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
  self.req_to_token = torch.zeros(
  (size + pre_alloc_size, max_context_len),
  dtype=torch.int32,
@@ -139,7 +136,7 @@ class DecodePreallocQueue:
  def __init__(
  self,
  req_to_token_pool: ReqToTokenPool,
- token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
  draft_token_to_kv_pool: Optional[KVCache],
  req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
  metadata_buffers: MetadataBuffers,
@@ -540,6 +537,7 @@ class DecodeTransferQueue:
  self.metadata_buffers = metadata_buffers
  self.scheduler = scheduler
  self.tree_cache = tree_cache
+ self.spec_algorithm = scheduler.spec_algorithm

  def add(self, decode_req: DecodeRequest) -> None:
  self.queue.append(decode_req)
@@ -585,10 +583,12 @@
  output_token_logprobs_idx,
  output_top_logprobs_val,
  output_top_logprobs_idx,
+ output_hidden_states,
  ) = self.metadata_buffers.get_buf(idx)

  decode_req.req.output_ids.append(output_id[0].item())
-
+ if not self.spec_algorithm.is_none():
+ decode_req.req.hidden_states_tensor = output_hidden_states
  if decode_req.req.return_logprob:
  decode_req.req.output_token_logprobs_val.append(
  output_token_logprobs_val[0].item()
@@ -645,10 +645,7 @@ class SchedulerDisaggregationDecodeMixin:
  batch = self.get_next_disagg_decode_batch_to_run()
  self.cur_batch = batch

- prepare_dp_attn_flag = (
- self.server_args.enable_dp_attention
- or self.server_args.enable_sp_layernorm
- )
+ prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

  if batch:
  # Generate fake extend output.
@@ -657,14 +654,14 @@
  self.stream_output(
  batch.reqs, any(req.return_logprob for req in batch.reqs)
  )
- if prepare_dp_attn_flag:
+ if prepare_mlp_sync_flag:
  self._prepare_idle_batch_and_run(None)
  else:
- if prepare_dp_attn_flag:
- self.prepare_dp_attn_batch(batch)
+ if prepare_mlp_sync_flag:
+ self.prepare_mlp_sync_batch(batch)
  result = self.run_batch(batch)
  self.process_batch_result(batch, result)
- elif prepare_dp_attn_flag:
+ elif prepare_mlp_sync_flag:
  batch, _ = self._prepare_idle_batch_and_run(None)

  if batch is None and (
@@ -695,10 +692,7 @@
  self.cur_batch = batch
  last_batch_in_queue = False

- prepare_dp_attn_flag = (
- self.server_args.enable_dp_attention
- or self.server_args.enable_sp_layernorm
- )
+ prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

  if batch:
  # Generate fake extend output.
@@ -707,7 +701,7 @@
  self.stream_output(
  batch.reqs, any(req.return_logprob for req in batch.reqs)
  )
- if prepare_dp_attn_flag:
+ if prepare_mlp_sync_flag:
  batch_, result = self._prepare_idle_batch_and_run(
  None, delay_process=True
  )
@@ -715,8 +709,8 @@
  result_queue.append((batch_.copy(), result))
  last_batch_in_queue = True
  else:
- if prepare_dp_attn_flag:
- self.prepare_dp_attn_batch(batch)
+ if prepare_mlp_sync_flag:
+ self.prepare_mlp_sync_batch(batch)
  result = self.run_batch(batch)
  result_queue.append((batch.copy(), result))

@@ -731,7 +725,7 @@
  self.set_next_batch_sampling_info_done(tmp_batch)
  last_batch_in_queue = True

- elif prepare_dp_attn_flag:
+ elif prepare_mlp_sync_flag:
  batch, result = self._prepare_idle_batch_and_run(
  None, delay_process=True
  )
@@ -761,8 +755,8 @@
  self.last_batch = batch
  self.last_batch_in_queue = last_batch_in_queue

- def _prepare_idle_batch_and_run(self, batch, delay_process=False):
- batch, _ = self.prepare_dp_attn_batch(batch)
+ def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+ batch, _ = self.prepare_mlp_sync_batch(batch)
  result = None
  if batch:
  result = self.run_batch(batch)
sglang/srt/disaggregation/decode_schedule_batch_mixin.py CHANGED
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
  )
  topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

+ hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+ hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
  # local import to avoid circular import
  from sglang.srt.speculative.eagle_utils import EagleDraftInput

  spec_info = EagleDraftInput(
  topk_p=topk_p,
  topk_index=topk_index,
- hidden_states=torch.ones(
- (b, model_config.hidden_size), device=self.device
- ),
+ hidden_states=hidden_states,
  verified_id=self.output_ids,
  )
  spec_info.prepare_for_extend(self)
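The change above replaces the placeholder torch.ones hidden states with the real per-request tensors carried over from the prefill server, stacked into one (batch, hidden_size) tensor for the EAGLE draft input. A tiny sketch of just the stacking step, with made-up shapes:

import torch

hidden_size = 4096  # assumed model hidden size
# one hidden-state vector per request, as transferred from the prefill server
per_req_hidden = [torch.randn(hidden_size) for _ in range(3)]
hidden_states = torch.stack(per_req_hidden, dim=0)  # shape: (3, hidden_size)
print(hidden_states.shape)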
sglang/srt/disaggregation/mini_lb.py CHANGED
@@ -18,6 +18,10 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse

  from sglang.srt.disaggregation.utils import PDRegistryRequest

+ AIOHTTP_STREAM_READ_CHUNK_SIZE = (
+ 1024 * 64
+ ) # 64KB, to prevent aiohttp's "Chunk too big" error
+

  def setup_logger():
  logger = logging.getLogger("pdlb")
@@ -154,7 +158,9 @@ class MiniLoadBalancer:
  else:
  yield chunk
  else:
- async for chunk in decode_response.content:
+ async for chunk in decode_response.content.iter_chunked(
+ AIOHTTP_STREAM_READ_CHUNK_SIZE
+ ):
  yield chunk

  return StreamingResponse(
@@ -212,15 +218,39 @@ async def get_server_info():
  )
  prefill_infos = []
  decode_infos = []
+ all_internal_states = []
+
  async with aiohttp.ClientSession() as session:
  for server in chain(prefill_servers):
  server_info = await session.get(f"{server}/get_server_info")
  prefill_infos.append(await server_info.json())
  for server in chain(decode_servers):
  server_info = await session.get(f"{server}/get_server_info")
- decode_infos.append(await server_info.json())
-
- return {"prefill": prefill_infos, "decode": decode_infos}
+ info_json = await server_info.json()
+ decode_infos.append(info_json)
+ # Extract internal_states from decode servers
+ if "internal_states" in info_json:
+ all_internal_states.extend(info_json["internal_states"])
+
+ # Return format expected by bench_one_batch_server.py
+ if all_internal_states:
+ return {
+ "internal_states": all_internal_states,
+ "prefill": prefill_infos,
+ "decode": decode_infos,
+ }
+ else:
+ # Fallback with dummy data if no internal states found
+ return {
+ "internal_states": [
+ {
+ "last_gen_throughput": 0.0,
+ "avg_spec_accept_length": None,
+ }
+ ],
+ "prefill": prefill_infos,
+ "decode": decode_infos,
+ }


  @app.get("/get_model_info")
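The first mini_lb.py hunk switches from iterating decode_response.content directly to StreamReader.iter_chunked(), which reads the proxied stream in fixed 64 KB pieces so oversized upstream chunks no longer trigger aiohttp's "Chunk too big" error. A self-contained sketch of the pattern; the URL and the proxy_stream helper are placeholders, not part of the diff.

import asyncio
import aiohttp

CHUNK_SIZE = 64 * 1024  # same 64 KB limit as AIOHTTP_STREAM_READ_CHUNK_SIZE

async def proxy_stream(url: str):
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            # iter_chunked() caps each read at CHUNK_SIZE bytes, so large
            # upstream chunks are split instead of raising an error.
            async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                yield chunk  # a real load balancer re-yields this to the client

async def main():
    async for chunk in proxy_stream("http://127.0.0.1:30000/generate"):
        print(len(chunk))

# asyncio.run(main())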