sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
-    if model_runner.server_args.enable_dp_attention:
-        Scheduler.prepare_dp_attn_batch_raw(
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         )

sglang/srt/_custom_ops.py CHANGED
@@ -4,7 +4,7 @@ from typing import List, Tuple

 import torch

-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

 logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
@@ -25,7 +25,7 @@ if not is_hpu():
         logger.warning("Failed to import from custom_ar with %r", e)


-if not is_hip():
+if not is_hip() and not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
     else:
sglang/srt/code_completion_parser.py CHANGED
@@ -15,12 +15,10 @@


 import dataclasses
-import json
 import logging
-import os
 from enum import auto

-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest

 logger = logging.getLogger(__name__)
 completion_template_name = None
@@ -57,46 +55,6 @@ class CompletionTemplate:
 completion_templates: dict[str, CompletionTemplate] = {}


-def load_completion_template_for_openai_api(completion_template_arg):
-    global completion_template_name
-
-    logger.info(
-        f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
-    )
-
-    if not completion_template_exists(completion_template_arg):
-        if not os.path.exists(completion_template_arg):
-            raise RuntimeError(
-                f"Completion template {completion_template_arg} is not a built-in template name "
-                "or a valid completion template file path."
-            )
-
-        assert completion_template_arg.endswith(
-            ".json"
-        ), "unrecognized format of completion template file"
-        with open(completion_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                fim_position = FimPosition[template["fim_position"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown fim position: {template['fim_position']}"
-                ) from None
-            register_completion_template(
-                CompletionTemplate(
-                    name=template["name"],
-                    fim_begin_token=template["fim_begin_token"],
-                    fim_middle_token=template["fim_middle_token"],
-                    fim_end_token=template["fim_end_token"],
-                    fim_position=fim_position,
-                ),
-                override=True,
-            )
-        completion_template_name = template["name"]
-    else:
-        completion_template_name = completion_template_arg
-
-
 def register_completion_template(template: CompletionTemplate, override: bool = False):
     """Register a new completion template."""
     if not override:
@@ -116,7 +74,7 @@ def is_completion_template_defined() -> bool:
     return completion_template_name is not None


-def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
     global completion_template_name
     if request.suffix == "":
         return request.prompt
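
The JSON-driven loader removed above is gone from this module; templates can still be registered programmatically through the surviving register_completion_template API. A hedged sketch of that call, using placeholder FIM token strings (the token values and the FimPosition member name are illustrative, not taken from any real model or from this diff):

# Illustrative sketch only: the token strings and the "fim_position" value are
# placeholders; a real registration must use a valid FimPosition member name.
fim_cfg = {
    "name": "my-fim-template",
    "fim_begin_token": "<fim_begin>",
    "fim_middle_token": "<fim_middle>",
    "fim_end_token": "<fim_end>",
    "fim_position": "middle",
}
register_completion_template(
    CompletionTemplate(
        name=fim_cfg["name"],
        fim_begin_token=fim_cfg["fim_begin_token"],
        fim_middle_token=fim_cfg["fim_middle_token"],
        fim_end_token=fim_cfg["fim_end_token"],
        fim_position=FimPosition[fim_cfg["fim_position"]],
    ),
    override=True,
)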
sglang/srt/constants.py ADDED
@@ -0,0 +1,3 @@
+# GPU Memory Types
+GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
+GPU_MEMORY_TYPE_WEIGHTS = "weights"
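
These constants are used elsewhere in this release as tags for scoping GPU allocations; for example, the decode.py hunk further down wraps KV-cache allocation in memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE). A rough, hypothetical sketch of what a tag-scoped region amounts to (the real adapter in torch_memory_saver_adapter.py wraps the torch_memory_saver package and differs in detail):

# Hypothetical sketch only: shows the idea of attributing allocations made
# inside a region to a tag so they can later be handled as a group.
from contextlib import contextmanager

GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_WEIGHTS = "weights"

_active_tags = []

@contextmanager
def region(tag: str):
    # Allocations made while this context is active are attributed to `tag`.
    _active_tags.append(tag)
    try:
        yield
    finally:
        _active_tags.pop()

with region(tag=GPU_MEMORY_TYPE_KV_CACHE):
    kv_buffer = bytearray(1024)  # stands in for a torch.zeros(...) allocation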
sglang/srt/conversation.py CHANGED
@@ -11,7 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Conversation chat templates."""
+"""Conversation chat templates.
+
+This module provides conversation template definitions, data structures, and utilities
+for managing chat templates across different model types in SGLang.
+
+Key components:
+- Conversation class: Defines the structure and behavior of chat templates
+- SeparatorStyle enum: Different conversation formatting styles
+- Template registry: Functions to register and retrieve templates by name or model path
+- Built-in templates: Pre-defined templates for popular models
+"""

 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union

-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 from sglang.srt.utils import read_system_prompt_from_file


@@ -618,7 +628,7 @@ def generate_chat_conv(


 # llama2 template
-# reference: https://huggingface.co/blog/codellama#conversational-instructions
+# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
 register_conv_template(
     Conversation(
sglang/srt/custom_op.py CHANGED
@@ -1,9 +1,11 @@
 from torch import nn

-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()


 class CustomOp(nn.Module):
@@ -75,5 +77,7 @@ class CustomOp(nn.Module):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
+        elif _is_cpu and _is_cpu_amx_available:
+            return self.forward_cpu
         else:
             return self.forward_native
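
The dispatch above picks a per-platform forward implementation once, based on module-level platform probes. A minimal standalone sketch of the same pattern (the is_*() probes are stubbed here for illustration; in sglang they come from sglang.srt.utils):

# Minimal sketch of the CustomOp-style dispatch shown above; platform flags are
# hard-coded stand-ins for the real platform probes.
from torch import nn

_is_cuda, _is_hip = False, False
_is_cpu, _is_cpu_amx_available = True, True

class MyOp(nn.Module):
    def __init__(self):
        super().__init__()
        # Resolve the backend once at construction time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        elif _is_hip:
            return self.forward_hip
        elif _is_cpu and _is_cpu_amx_available:
            # The branch this diff adds: take the CPU path only when AMX is available.
            return self.forward_cpu
        else:
            return self.forward_native

    def forward_native(self, x):
        return x

    # In this sketch all backends share the native implementation.
    forward_cuda = forward_hip = forward_cpu = forward_native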
sglang/srt/disaggregation/decode.py CHANGED
@@ -21,16 +21,15 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
-import os
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch.distributed import ProcessGroup

+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
@@ -46,14 +45,12 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import (
-    KVCache,
-    ReqToTokenPool,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import require_mlp_sync

 logger = logging.getLogger(__name__)

@@ -90,7 +87,7 @@ class DecodeReqToTokenPool:
         self.max_context_len = max_context_len
         self.device = device
         self.pre_alloc_size = pre_alloc_size
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(tag=GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size + pre_alloc_size, max_context_len),
                 dtype=torch.int32,
@@ -139,7 +136,7 @@ class DecodePreallocQueue:
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
         metadata_buffers: MetadataBuffers,
@@ -540,6 +537,7 @@ class DecodeTransferQueue:
         self.metadata_buffers = metadata_buffers
         self.scheduler = scheduler
         self.tree_cache = tree_cache
+        self.spec_algorithm = scheduler.spec_algorithm

     def add(self, decode_req: DecodeRequest) -> None:
         self.queue.append(decode_req)
@@ -581,6 +579,7 @@ class DecodeTransferQueue:
             idx = decode_req.metadata_buffer_index
             (
                 output_id,
+                output_hidden_states,
                 output_token_logprobs_val,
                 output_token_logprobs_idx,
                 output_top_logprobs_val,
@@ -588,7 +587,8 @@ class DecodeTransferQueue:
             ) = self.metadata_buffers.get_buf(idx)

             decode_req.req.output_ids.append(output_id[0].item())
-
+            if not self.spec_algorithm.is_none():
+                decode_req.req.hidden_states_tensor = output_hidden_states
             if decode_req.req.return_logprob:
                 decode_req.req.output_token_logprobs_val.append(
                     output_token_logprobs_val[0].item()
@@ -645,10 +645,7 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch

-            prepare_dp_attn_flag = (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

             if batch:
                 # Generate fake extend output.
@@ -657,14 +654,14 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if prepare_dp_attn_flag:
+                    if prepare_mlp_sync_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:
-                    if prepare_dp_attn_flag:
-                        self.prepare_dp_attn_batch(batch)
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
-            elif prepare_dp_attn_flag:
+            elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
@@ -695,10 +692,7 @@ class SchedulerDisaggregationDecodeMixin:
             self.cur_batch = batch
             last_batch_in_queue = False

-            prepare_dp_attn_flag = (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

             if batch:
                 # Generate fake extend output.
@@ -707,7 +701,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if prepare_dp_attn_flag:
+                    if prepare_mlp_sync_flag:
                         batch_, result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
@@ -715,8 +709,8 @@ class SchedulerDisaggregationDecodeMixin:
                             result_queue.append((batch_.copy(), result))
                             last_batch_in_queue = True
                 else:
-                    if prepare_dp_attn_flag:
-                        self.prepare_dp_attn_batch(batch)
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))

@@ -731,7 +725,7 @@ class SchedulerDisaggregationDecodeMixin:
                     self.set_next_batch_sampling_info_done(tmp_batch)
                 last_batch_in_queue = True

-            elif prepare_dp_attn_flag:
+            elif prepare_mlp_sync_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
@@ -761,8 +755,8 @@ class SchedulerDisaggregationDecodeMixin:
             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue

-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.prepare_dp_attn_batch(batch)
+    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+        batch, _ = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
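
Both decode loops above now compute a single prepare_mlp_sync_flag via require_mlp_sync(self.server_args) instead of checking the DP-attention flags inline. A hypothetical sketch of the kind of predicate this centralizes (the real helper lives in sglang.srt.utils and may consult more server arguments than the two flags the old call sites checked):

# Hypothetical simplification, not the real sglang.srt.utils implementation.
from types import SimpleNamespace

def require_mlp_sync(server_args) -> bool:
    # The old call sites gated on these two flags directly; the helper just
    # centralizes whether ranks must agree on batch shape before the MLP layers.
    return bool(
        getattr(server_args, "enable_dp_attention", False)
        or getattr(server_args, "enable_sp_layernorm", False)
    )

args = SimpleNamespace(enable_dp_attention=True, enable_sp_layernorm=False)
assert require_mlp_sync(args) is True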
sglang/srt/disaggregation/decode_schedule_batch_mixin.py CHANGED
@@ -126,15 +126,16 @@ class ScheduleBatchDisaggregationDecodeMixin:
         )
         topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)

+        hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
+        hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
+
         # local import to avoid circular import
         from sglang.srt.speculative.eagle_utils import EagleDraftInput

         spec_info = EagleDraftInput(
             topk_p=topk_p,
             topk_index=topk_index,
-            hidden_states=torch.ones(
-                (b, model_config.hidden_size), device=self.device
-            ),
+            hidden_states=hidden_states,
             verified_id=self.output_ids,
         )
         spec_info.prepare_for_extend(self)
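
This change replaces the placeholder torch.ones hidden states with the real per-request tensors carried over from the prefill server, stacked into one batch tensor. A small self-contained sketch of that stacking step (shapes are illustrative only):

# Illustrative only: stack b per-request hidden-state vectors of shape
# [hidden_size] into a single [b, hidden_size] batch tensor.
import torch

hidden_size, b = 8, 3
per_request = [torch.randn(hidden_size) for _ in range(b)]  # like req.hidden_states_tensor
hidden_states = torch.stack(per_request, dim=0)             # shape: [b, hidden_size]
assert hidden_states.shape == (b, hidden_size)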
sglang/srt/disaggregation/mini_lb.py CHANGED
@@ -18,6 +18,10 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import PDRegistryRequest

+AIOHTTP_STREAM_READ_CHUNK_SIZE = (
+    1024 * 64
+)  # 64KB, to prevent aiohttp's "Chunk too big" error
+

 def setup_logger():
     logger = logging.getLogger("pdlb")
@@ -154,7 +158,9 @@ class MiniLoadBalancer:
                     else:
                         yield chunk
                 else:
-                    async for chunk in decode_response.content:
+                    async for chunk in decode_response.content.iter_chunked(
+                        AIOHTTP_STREAM_READ_CHUNK_SIZE
+                    ):
                         yield chunk

         return StreamingResponse(
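
Reading the upstream response with StreamReader.iter_chunked(n) caps how much data aiohttp hands back per iteration, which is what the 64KB constant above is for. A minimal standalone sketch of the pattern (the URL is a placeholder, not part of the diff):

# Minimal sketch of chunked streaming with aiohttp; the endpoint is a placeholder.
import asyncio
import aiohttp

CHUNK_SIZE = 1024 * 64  # 64KB per read, mirroring AIOHTTP_STREAM_READ_CHUNK_SIZE

async def relay():
    async with aiohttp.ClientSession() as session:
        async with session.get("http://localhost:8000/generate") as resp:
            async for chunk in resp.content.iter_chunked(CHUNK_SIZE):
                # Forward each bounded-size chunk downstream.
                print(len(chunk))

# asyncio.run(relay())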
@@ -212,15 +218,39 @@ async def get_server_info():
     )
     prefill_infos = []
     decode_infos = []
+    all_internal_states = []
+
    async with aiohttp.ClientSession() as session:
         for server in chain(prefill_servers):
             server_info = await session.get(f"{server}/get_server_info")
             prefill_infos.append(await server_info.json())
         for server in chain(decode_servers):
             server_info = await session.get(f"{server}/get_server_info")
-            decode_infos.append(await server_info.json())
-
-    return {"prefill": prefill_infos, "decode": decode_infos}
+            info_json = await server_info.json()
+            decode_infos.append(info_json)
+            # Extract internal_states from decode servers
+            if "internal_states" in info_json:
+                all_internal_states.extend(info_json["internal_states"])
+
+    # Return format expected by bench_one_batch_server.py
+    if all_internal_states:
+        return {
+            "internal_states": all_internal_states,
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }
+    else:
+        # Fallback with dummy data if no internal states found
+        return {
+            "internal_states": [
+                {
+                    "last_gen_throughput": 0.0,
+                    "avg_spec_accept_length": None,
+                }
+            ],
+            "prefill": prefill_infos,
+            "decode": decode_infos,
+        }


 @app.get("/get_model_info")
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
 from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import (
-    get_free_port,
-    get_int_env_var,
-    get_ip,
-    get_local_ip_by_remote,
-)
+from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto

 logger = logging.getLogger(__name__)

@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
         is_mla_backend: Optional[bool] = False,
     ):
         self.kv_args = args
+        self.local_ip = get_local_ip_auto()
         self.engine = MooncakeTransferEngine(
-            hostname=get_local_ip_by_remote(),
+            hostname=self.local_ip,
             gpu_id=self.kv_args.gpu_id,
             ib_device=self.kv_args.ib_device,
         )
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):

     def start_prefill_thread(self):
         self.rank_port = get_free_port()
-        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
+        self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")

         def bootstrap_thread():
             """This thread recvs pre-alloc notification from the decode engine"""
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):

     def start_decode_thread(self):
         self.rank_port = get_free_port()
-        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
+        self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")

         def decode_thread():
             while True:
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
             "role": "Prefill",
             "tp_size": self.tp_size,
             "dp_size": self.dp_size,
-            "rank_ip": get_local_ip_by_remote(),
+            "rank_ip": self.local_ip,
             "rank_port": self.rank_port,
             "engine_rank": self.kv_args.engine_rank,
         }
@@ -746,12 +742,12 @@ class MooncakeKVSender(BaseKVSender):
             self.kv_mgr.request_status.pop(self.bootstrap_room)

     def failure_exception(self):
-        self.clear()
-
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
             self.conclude_state = KVPoll.Failed

+        self.clear()
+
         with self.kv_mgr.failure_lock:
             failure_reason = self.kv_mgr.failure_records.pop(
                 self.bootstrap_room, "Failed due to an unknown reason from another rank"
@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             sock.send_multipart(
                 [
                     "None".encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
+                    self.kv_mgr.local_ip.encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
                     packed_kv_data_ptrs,
@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             sock.send_multipart(
                 [
                     str(self.bootstrap_room).encode("ascii"),
-                    get_local_ip_by_remote().encode("ascii"),
+                    self.kv_mgr.local_ip.encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
                     self.session_id.encode("ascii"),
                     kv_indices.tobytes() if not is_dummy else b"",
@@ -1007,12 +1003,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.kv_mgr.request_status.pop(self.bootstrap_room)

     def failure_exception(self):
-        self.clear()
-
         # Explicitly set the status to failure since this request has failed in another rank
         if self.conclude_state is None:
             self.conclude_state = KVPoll.Failed

+        self.clear()
+
         with self.kv_mgr.failure_lock:
             failure_reason = self.kv_mgr.failure_records.pop(
                 self.bootstrap_room, "Failed due to an unknown reason from another rank"
sglang/srt/disaggregation/prefill.py CHANGED
@@ -25,7 +25,6 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional

-import numpy as np
 import torch

 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -45,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import require_mlp_sync

 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -274,12 +274,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

-            # Handle DP attention
-            if (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch

             if batch:
@@ -312,12 +308,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

-            # Handle DP attention
-            if (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)
             self.cur_batch = batch
             if batch:
                 result = self.run_batch(batch)
@@ -393,6 +385,8 @@ class SchedulerDisaggregationPrefillMixin:
             logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
+
+        hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
@@ -402,6 +396,16 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 self.disagg_prefill_inflight_queue.append(req)
+                if logits_output.hidden_states is not None:
+                    last_hidden_index = (
+                        hidden_state_offset + extend_input_len_per_req[i] - 1
+                    )
+                    req.hidden_states_tensor = (
+                        logits_output.hidden_states[last_hidden_index].cpu().clone()
+                    )
+                    hidden_state_offset += extend_input_len_per_req[i]
+                else:
+                    req.hidden_states_tensor = None
                 if req.return_logprob:
                     assert extend_logprob_start_len_per_req is not None
                     assert extend_input_len_per_req is not None
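
In this hunk the hidden states for all requests in the batch come back packed along one dimension, so the last hidden state of request i sits at hidden_state_offset + extend_input_len_per_req[i] - 1, with the offset advanced by each request's length. A small sketch of that bookkeeping (shapes and lengths are illustrative only):

# Illustrative only: pick the last-position hidden state of each request out of
# a packed [sum(extend_lens), hidden_size] tensor, as the prefill hunk does.
import torch

hidden_size = 8
extend_input_len_per_req = [4, 2, 5]                     # prompt tokens per request
packed = torch.randn(sum(extend_input_len_per_req), hidden_size)

offset = 0
last_hidden_per_req = []
for length in extend_input_len_per_req:
    last_hidden_per_req.append(packed[offset + length - 1].clone())
    offset += length

assert len(last_hidden_per_req) == len(extend_input_len_per_req)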