sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_runtime_checker_mixin.py

@@ -1,5 +1,8 @@
 from __future__ import annotations

+import logging
+import signal
+import sys
 import time
 from typing import TYPE_CHECKING

@@ -7,10 +10,13 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
+from sglang.srt.utils.common import disable_request_logging, pyspy_dump_schedulers

 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import Scheduler

+logger = logging.getLogger(__name__)
+

 class SchedulerRuntimeCheckerMixin:

@@ -215,3 +221,42 @@ class SchedulerRuntimeCheckerMixin:
         self.check_tree_cache()
         self.new_token_ratio = self.init_new_token_ratio
         self.maybe_sleep_on_idle()
+
+    def watchdog_thread(self: Scheduler):
+        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
+        self.watchdog_last_forward_ct = 0
+        self.watchdog_last_time = time.perf_counter()
+
+        while True:
+            current = time.perf_counter()
+            if self.cur_batch is not None:
+                if self.watchdog_last_forward_ct == self.forward_ct:
+                    if current > self.watchdog_last_time + self.watchdog_timeout:
+                        break
+                else:
+                    self.watchdog_last_forward_ct = self.forward_ct
+                    self.watchdog_last_time = current
+            time.sleep(self.watchdog_timeout // 2)
+
+        if not disable_request_logging():
+            # Print batch size and memory pool info to check whether there are de-sync issues.
+            if self.is_hybrid:
+                _, info_msg = self._check_hybrid_memory()
+            elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
+                _, info_msg = self._check_mamba_memory()
+            else:
+                _, info_msg = self._check_radix_cache_memory()
+            logger.error(
+                f"{self.cur_batch.batch_size()=}\n"
+                f"{self.cur_batch.reqs=}\n"
+                f"{info_msg}"
+            )
+
+        pyspy_dump_schedulers()
+        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
+        print(file=sys.stderr, flush=True)
+        print(file=sys.stdout, flush=True)
+
+        # Wait for some time so that the parent process can print the error.
+        time.sleep(5)
+        self.parent_process.send_signal(signal.SIGQUIT)
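
The watchdog added above follows a common stall-detection pattern: remember the forward counter, sleep for half the timeout, and escalate only if the counter has not advanced while a batch is in flight. A minimal self-contained sketch of the same pattern (the Worker class and all names below are hypothetical, not part of sglang):

import os
import signal
import threading
import time


class Worker:
    """Hypothetical worker whose progress is tracked by a forward counter."""

    def __init__(self, timeout: float):
        self.timeout = timeout
        self.forward_ct = 0
        self.busy = False

    def watchdog(self):
        last_ct, last_time = 0, time.perf_counter()
        while True:
            now = time.perf_counter()
            if self.busy:
                if last_ct == self.forward_ct:
                    if now > last_time + self.timeout:
                        break  # no progress for a full timeout window
                else:
                    last_ct, last_time = self.forward_ct, now
            time.sleep(self.timeout / 2)
        # Escalate: here we signal ourselves instead of a parent process.
        os.kill(os.getpid(), signal.SIGTERM)

    def run(self, steps: int):
        self.busy = True
        for _ in range(steps):
            time.sleep(0.1)       # simulated forward pass
            self.forward_ct += 1  # progress keeps the watchdog quiet
        self.busy = False


if __name__ == "__main__":
    w = Worker(timeout=2.0)
    threading.Thread(target=w.watchdog, daemon=True).start()
    w.run(steps=10)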
sglang/srt/managers/scheduler_update_weights_mixin.py

@@ -5,7 +5,12 @@ from typing import TYPE_CHECKING, Tuple

 import torch

-from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
+from sglang.srt.constants import (
+    GPU_MEMORY_ALL_TYPES,
+    GPU_MEMORY_TYPE_CUDA_GRAPH,
+    GPU_MEMORY_TYPE_KV_CACHE,
+    GPU_MEMORY_TYPE_WEIGHTS,
+)
 from sglang.srt.managers.io_struct import (
     DestroyWeightsUpdateGroupReqInput,
     DestroyWeightsUpdateGroupReqOutput,
@@ -101,10 +106,14 @@ class SchedulerUpdateWeightsMixin:
     def release_memory_occupation(
         self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
     ):
+        assert (
+            self._is_no_request()
+        ), "release_memory_occupation should be called only when no ongoing request."
+
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+            tags = GPU_MEMORY_ALL_TYPES

         for tag in tags:
             self.offload_tags.add(tag)
@@ -120,6 +129,9 @@ class SchedulerUpdateWeightsMixin:
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

+        if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_CUDA_GRAPH)
+
         return ReleaseMemoryOccupationReqOutput()

     def resume_memory_occupation(
@@ -128,11 +140,14 @@
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
-            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+            tags = GPU_MEMORY_ALL_TYPES

         for tag in tags:
             self.offload_tags.remove(tag)

+        if GPU_MEMORY_TYPE_CUDA_GRAPH in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_CUDA_GRAPH)
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
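
Release/resume now defaults to GPU_MEMORY_ALL_TYPES and pauses or resumes a separate CUDA-graph tag. A minimal sketch of the tag-dispatch idea; the adapter class and the string tag values below are illustrative stand-ins, not the real torch_memory_saver adapter or constants:

from typing import Iterable, Optional

# Tag names mirror sglang.srt.constants; the string values here are assumptions.
GPU_MEMORY_TYPE_WEIGHTS = "weights"
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
GPU_MEMORY_TYPE_CUDA_GRAPH = "cuda_graph"
GPU_MEMORY_ALL_TYPES = [
    GPU_MEMORY_TYPE_WEIGHTS,
    GPU_MEMORY_TYPE_KV_CACHE,
    GPU_MEMORY_TYPE_CUDA_GRAPH,
]


class DummyMemorySaver:
    """Stand-in for the memory saver adapter: just records paused tags."""

    def __init__(self):
        self.paused = set()

    def pause(self, tag: str):
        self.paused.add(tag)

    def resume(self, tag: str):
        self.paused.discard(tag)


def release(saver: DummyMemorySaver, tags: Optional[Iterable[str]]):
    # An empty/None tag list now means "everything", including CUDA graphs.
    tags = list(tags) if tags else GPU_MEMORY_ALL_TYPES
    for tag in tags:
        saver.pause(tag)
    return tags


saver = DummyMemorySaver()
print(release(saver, None))          # pauses all three memory types
print(release(saver, ["kv_cache"]))  # pauses only the KV cache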
sglang/srt/managers/session_controller.py

@@ -15,11 +15,11 @@ import uuid
 from typing import Dict, Optional

 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
-from sglang.srt.managers.schedule_batch import Req
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req


 class SessionReqNode:
-    def __init__(self, req, parent=None, childs=None):
+    def __init__(self, req: Req, parent=None, childs=None):
         self.req = req
         self.parent = parent
         if parent is not None:
@@ -36,12 +36,12 @@ class SessionReqNode:
             req_node.clear(req_dict)

         if self.req.finished_reason is None:
-            self.req.to_abort = True
+            self.req.to_finish = FINISH_ABORT()
         del req_dict[self.req.rid]

     def abort(self):
         if self.req.finished_reason is None:
-            self.req.to_abort = True
+            self.req.to_finish = FINISH_ABORT()

     def __str__(self):
         return self._str_helper(self.req.rid)
@@ -137,13 +137,14 @@ class Session:
             origin_input_ids=input_ids,
             origin_input_ids_unpadded=input_ids_unpadded,
             sampling_params=req.sampling_params,
-            lora_path=req.lora_path,
+            lora_id=req.lora_id,
             session_id=self.session_id,
             custom_logit_processor=req.custom_logit_processor,
             stream=req.stream,
             return_logprob=req.return_logprob,
             top_logprobs_num=req.top_logprobs_num,
             token_ids_logprob=req.token_ids_logprob,
+            vocab_size=tokenizer.vocab_size,
         )
         if last_req is not None:
             new_req.multimodal_inputs = last_req.multimodal_inputs
sglang/srt/managers/tokenizer_manager.py

@@ -43,6 +43,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.lora.lora_registry import LoRARegistry
 from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -68,6 +69,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.schedule_batch import RequestStage
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
@@ -79,6 +81,7 @@ from sglang.srt.tracing.trace import (
     trace_get_proc_propagate_context,
     trace_req_finish,
     trace_req_start,
+    trace_set_remote_propagate_context,
     trace_slice_end,
     trace_slice_start,
 )
@@ -213,6 +216,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             self.mm_processor = get_mm_processor(
                 self.model_config.hf_config, server_args, _processor, transport_mode
             )
+            self.mm_data_processor = AsyncMMDataProcessor(
+                self.mm_processor,
+                max_concurrent_calls=self.server_args.mm_max_concurrent_calls,
+                timeout_s=self.server_args.mm_per_request_timeout,
+            )

         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
@@ -383,6 +391,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

+        if request:
+            if "trace_context" in request.headers:
+                trace_set_remote_propagate_context(request.headers["trace_context"])
+
         if self.server_args.tokenizer_worker_num > 1:
             self._attach_multi_http_worker_info(obj)

@@ -592,10 +604,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 obj.image_data = [obj.image_data]
             if obj.audio_data is not None and not isinstance(obj.audio_data, list):
                 obj.audio_data = [obj.audio_data]
-            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
+            mm_inputs: Dict = await self.mm_data_processor.process(
                 image_data=obj.image_data,
                 audio_data=obj.audio_data,
-                input_text=input_text or input_ids,
+                input_text_or_ids=(input_text or input_ids),
                 request_obj=obj,
                 max_req_input_len=self.max_req_input_len,
             )
@@ -605,7 +617,7 @@
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
-        trace_slice_end("tokenize", obj.rid)
+        trace_slice_end(RequestStage.TOKENIZE, obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -666,6 +678,10 @@
             )
             raise ValueError(error_msg)

+        # Matryoshka embeddings validations
+        if isinstance(obj, EmbeddingReqInput):
+            self._validate_for_matryoshka_dim(obj)
+
         if isinstance(obj, GenerateReqInput):
             if (
                 obj.return_hidden_states
@@ -684,6 +700,34 @@
                     "Please set `--enable-custom-logit-processor` to enable this feature."
                 )

+    def _validate_for_matryoshka_dim(self, obj: EmbeddingReqInput) -> None:
+        """Validate the request for Matryoshka dim if it has the field set."""
+        if obj.dimensions is None:
+            return
+
+        if not self.model_config.is_matryoshka:
+            raise ValueError(
+                f"Model '{self.model_config.model_path}' does not support matryoshka representation, "
+                f"changing output dimensions will lead to poor results."
+            )
+
+        if obj.dimensions < 1:
+            raise ValueError("Requested dimensions must be greater than 0")
+
+        if (
+            self.model_config.matryoshka_dimensions
+            and obj.dimensions not in self.model_config.matryoshka_dimensions
+        ):
+            raise ValueError(
+                f"Model '{self.model_config.model_path}' only supports {self.model_config.matryoshka_dimensions} matryoshka dimensions, "
+                f"using other output dimensions will lead to poor results."
+            )
+
+        if obj.dimensions > self.model_config.hidden_size:
+            raise ValueError(
+                f"Provided dimensions are greater than max embedding dimension: {self.model_config.hidden_size}"
+            )
+
     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
     ) -> None:
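
The new _validate_for_matryoshka_dim check rejects a requested `dimensions` value unless the model advertises matryoshka support and the value is positive, listed (when `matryoshka_dimensions` is set), and no larger than the hidden size. A standalone sketch of the same checks; the config dataclass and example values below are hypothetical:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class EmbedConfig:
    model_path: str
    is_matryoshka: bool
    hidden_size: int
    matryoshka_dimensions: Optional[List[int]] = None


def validate_dimensions(cfg: EmbedConfig, dimensions: Optional[int]) -> None:
    if dimensions is None:
        return
    if not cfg.is_matryoshka:
        raise ValueError(f"{cfg.model_path} does not support matryoshka representation")
    if dimensions < 1:
        raise ValueError("Requested dimensions must be greater than 0")
    if cfg.matryoshka_dimensions and dimensions not in cfg.matryoshka_dimensions:
        raise ValueError(f"only {cfg.matryoshka_dimensions} are supported")
    if dimensions > cfg.hidden_size:
        raise ValueError(f"dimensions exceed hidden size {cfg.hidden_size}")


cfg = EmbedConfig("example/embedding-model", True, 1024, [256, 512, 1024])
validate_dimensions(cfg, 512)    # passes
# validate_dimensions(cfg, 384)  # would raise: not in the advertised list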
@@ -752,6 +796,7 @@
             sampling_params,
             rid=obj.rid,
             priority=obj.priority,
+            dimensions=obj.dimensions,
             http_worker_ipc=obj.http_worker_ipc,
         )

@@ -798,7 +843,7 @@
                     req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
-            trace_slice_end("tokenize", req.rid)
+            trace_slice_end(RequestStage.TOKENIZE, req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs

@@ -850,12 +895,14 @@
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        trace_slice_start("dispatch", obj.rid)
+        trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
         tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
+        trace_slice_end(
+            RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
+        )
         return state

     def _send_batch_request(
@@ -1357,6 +1404,7 @@
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
+                "total_retractions": recv_obj.retraction_counts[i],
             }

             if getattr(state.obj, "return_logprob", False):
@@ -1445,6 +1493,51 @@
         if self.crash_dump_folder and state.finished and state.obj.log_metrics:
             self.record_request_for_crash_dump(state, out_dict)

+    def add_logprob_to_meta_info(
+        self,
+        meta_info: dict,
+        state: ReqState,
+        top_logprobs_num: int,
+        token_ids_logprob: List[int],
+        return_text_in_logprobs: bool,
+    ):
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
+            return_text_in_logprobs,
+        )
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
+                return_text_in_logprobs,
+            )
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
+                return_text_in_logprobs,
+            )
+
+        if token_ids_logprob is not None:
+            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
+                return_text_in_logprobs,
+            )
+            meta_info["output_token_ids_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
+                    return_text_in_logprobs,
+                )
+            )
+
     def convert_logprob_style(
         self,
         meta_info: dict,
@@ -1471,16 +1564,6 @@
         state.output_token_logprobs_idx.extend(
             recv_obj.output_token_logprobs_idx[recv_obj_index]
         )
-        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            state.input_token_logprobs_val,
-            state.input_token_logprobs_idx,
-            return_text_in_logprobs,
-        )
-        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            state.output_token_logprobs_val,
-            state.output_token_logprobs_idx,
-            return_text_in_logprobs,
-        )

         if top_logprobs_num > 0:
             if len(recv_obj.input_top_logprobs_val) > 0:
@@ -1496,16 +1579,6 @@
             state.output_top_logprobs_idx.extend(
                 recv_obj.output_top_logprobs_idx[recv_obj_index]
             )
-            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                state.input_top_logprobs_val,
-                state.input_top_logprobs_idx,
-                return_text_in_logprobs,
-            )
-            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                state.output_top_logprobs_val,
-                state.output_top_logprobs_idx,
-                return_text_in_logprobs,
-            )

         if token_ids_logprob is not None:
             if len(recv_obj.input_token_ids_logprobs_val) > 0:
@@ -1521,18 +1594,14 @@
             state.output_token_ids_logprobs_idx.extend(
                 recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
             )
-            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                state.input_token_ids_logprobs_val,
-                state.input_token_ids_logprobs_idx,
-                return_text_in_logprobs,
-            )
-            meta_info["output_token_ids_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    state.output_token_ids_logprobs_val,
-                    state.output_token_ids_logprobs_idx,
-                    return_text_in_logprobs,
-                )
-            )
+
+        self.add_logprob_to_meta_info(
+            meta_info,
+            state,
+            state.obj.top_logprobs_num,
+            state.obj.token_ids_logprob,
+            return_text_in_logprobs,
+        )

     def detokenize_logprob_tokens(
         self,
@@ -1649,6 +1718,14 @@
             or state.obj.sampling_params.get("ebnf", None)
             or state.obj.sampling_params.get("structural_tag", None)
         )
+
+        retraction_count = (
+            recv_obj.retraction_counts[i]
+            if getattr(recv_obj, "retraction_counts", None)
+            and i < len(recv_obj.retraction_counts)
+            else 0
+        )
+
         self.metrics_collector.observe_one_finished_request(
             labels,
             recv_obj.prompt_tokens[i],
@@ -1656,6 +1733,7 @@
             recv_obj.cached_tokens[i],
             state.finished_time - state.created_time,
             has_grammar,
+            retraction_count,
         )

     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1708,26 +1786,33 @@
             return
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
+
+        abort_message = recv_obj.abort_message or "Abort in waiting queue"
+        finish_reason = {
+            "type": "abort",
+            "message": abort_message,
+        }
         if recv_obj.finished_reason:
-            out = {
-                "meta_info": {
-                    "id": recv_obj.rid,
-                    "finish_reason": recv_obj.finished_reason,
-                },
-            }
-        else:
-            out = {
-                "text": "",
-                "meta_info": {
-                    "id": recv_obj.rid,
-                    "finish_reason": {
-                        "type": "abort",
-                        "message": "Abort before prefill",
-                    },
-                    "prompt_tokens": 0,
-                    "completion_tokens": 0,
-                },
-            }
+            finish_reason = recv_obj.finished_reason
+        meta_info = {"id": recv_obj.rid, "finish_reason": finish_reason}
+        is_stream = getattr(state.obj, "stream", False)
+        if getattr(state.obj, "return_logprob", False):
+            self.add_logprob_to_meta_info(
+                meta_info,
+                state,
+                state.obj.top_logprobs_num,
+                state.obj.token_ids_logprob,
+                state.obj.return_text_in_logprobs
+                and not self.server_args.skip_tokenizer_init,
+            )
+
+        output_ids = state.output_ids
+        meta_info["completion_tokens"] = len(output_ids)
+        out = {
+            "text": state.text,
+            "output_ids": [output_ids[-1]] if is_stream else output_ids,
+            "meta_info": meta_info,
+        }
         state.out_list.append(out)
         state.event.set()

@@ -2088,7 +2173,12 @@
             bootstrap_room = (
                 obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
             )
-            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_req_start(
+                obj.rid,
+                bootstrap_room,
+                ts=int(created_time * 1e9),
+                role=self.server_args.disaggregation_mode,
+            )
             trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
         else:
             for i in range(len(obj.rid)):
@@ -2097,7 +2187,12 @@
                     if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
                     else None
                 )
-                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_req_start(
+                    obj.rid[i],
+                    bootstrap_room,
+                    ts=int(created_time * 1e9),
+                    role=self.server_args.disaggregation_mode,
+                )
                 trace_slice_start(
                     "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
                 )
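
With the abort path rewritten in the tokenizer_manager hunks above, an aborted request now carries text, output_ids, and a meta_info whose finish_reason defaults to an abort dict instead of the old near-empty payload. A sketch of the resulting shape; all field values below are made up for illustration:

# Illustrative shape of the abort output assembled above; values are invented.
out = {
    "text": "partial text generated before the abort",
    "output_ids": [101, 2023, 2003],  # full list, or only the last id when streaming
    "meta_info": {
        "id": "req-123",
        "finish_reason": {"type": "abort", "message": "Abort in waiting queue"},
        "completion_tokens": 3,
    },
}
print(out["meta_info"]["finish_reason"]["type"])  # abort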
sglang/srt/managers/tp_worker.py

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -425,3 +425,26 @@ class TpModelWorker(BaseTpWorker):
             pp_hidden_states_proxy_tensors=pp_proxy_tensors,
             can_run_cuda_graph=can_run_cuda_graph,
         )
+
+    def forward_batch_split_prefill(self, batch: ScheduleBatch):
+        if batch.split_index == 0:
+            model_worker_batch = batch.get_model_worker_batch()
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+            batch.split_forward_batch = forward_batch
+            batch.seq_lens_cpu_cache = model_worker_batch.seq_lens_cpu
+        else:
+            model_worker_batch = batch.get_model_worker_batch(batch.seq_lens_cpu_cache)
+
+        logits_output, can_run_cuda_graph = self.model_runner.forward(
+            batch.split_forward_batch, split_forward_count=batch.split_forward_count
+        )
+        if logits_output:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        else:
+            next_token_ids = None
+        batch_result = GenerationBatchResult(
+            logits_output=logits_output,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+        batch_result.next_token_ids = next_token_ids
+        return batch_result
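
forward_batch_split_prefill runs one slice of a split prefill: the first call builds and caches the forward batch, later calls reuse it, and sampling happens only once logits are produced. A toy illustration of that idea with no sglang types involved; the ToyModel and its behavior are entirely hypothetical:

from typing import List, Optional


class ToyModel:
    """Toy stand-in: accumulates prompt chunks and emits 'logits' on the last one."""

    def __init__(self, prompt_len: int):
        self.prompt_len = prompt_len
        self.seen = 0

    def forward(self, chunk: List[int]) -> Optional[List[int]]:
        self.seen += len(chunk)
        # Logits only become available once the whole prompt has been prefilled.
        return chunk if self.seen >= self.prompt_len else None


def split_prefill(prompt: List[int], chunk_size: int) -> Optional[int]:
    model = ToyModel(len(prompt))
    logits = None
    for start in range(0, len(prompt), chunk_size):
        logits = model.forward(prompt[start : start + chunk_size])
    # "Sample" (argmax stand-in) only if the final split produced logits.
    return max(logits) if logits else None


print(split_prefill(list(range(10)), chunk_size=4))  # -> 9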
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,12 +1,31 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    runtime_checkable,
+)

 import torch

+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
-else:
-    Req = Any  # Placeholder for Req type when not type checking
+
+
+@runtime_checkable
+class PrefixCacheTrait(Protocol):
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator
+    page_size: int
+    disable: bool


 class MatchResult(NamedTuple):
@@ -28,7 +47,7 @@ class MatchResult(NamedTuple):
     host_hit_length: int = 0


-class BasePrefixCache(ABC):
+class BasePrefixCache(ABC, PrefixCacheTrait):
     """Cache can be indexed by either rid or key."""

     @abstractmethod
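
PrefixCacheTrait is a typing.Protocol marked @runtime_checkable, so isinstance() only verifies that the listed attributes are present rather than requiring inheritance. A minimal standard-library example of the pattern; the classes below are hypothetical:

from typing import Protocol, runtime_checkable


@runtime_checkable
class CacheLike(Protocol):
    page_size: int
    disable: bool


class MyCache:
    def __init__(self):
        self.page_size = 16
        self.disable = False


class NotACache:
    pass


print(isinstance(MyCache(), CacheLike))    # True: structural match on attributes
print(isinstance(NotACache(), CacheLike))  # False: required attributes missing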
sglang/srt/mem_cache/common.py

@@ -89,6 +89,7 @@ def write_cache_indices(
     prefix_pointers = torch.tensor(
         [t.data_ptr() for t in prefix_tensors],
         device=req_to_token_pool.device,
+        dtype=torch.uint64,
     )
     # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
     write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
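
The added dtype=torch.uint64 makes the pointer tensor's dtype explicit: Tensor.data_ptr() returns the raw address as a Python int, and an unsigned 64-bit integer is the natural representation for a device pointer. A small sketch, assuming a PyTorch build that supports the uint64 dtype:

import torch

tensors = [torch.zeros(4), torch.zeros(8)]

# data_ptr() returns the raw memory address of each tensor as a Python int.
ptrs = [t.data_ptr() for t in tensors]

# Without an explicit dtype this would default to int64; uint64 matches an address.
ptr_tensor = torch.tensor(ptrs, dtype=torch.uint64)
print(ptr_tensor.dtype)  # torch.uint64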