sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -34,13 +34,21 @@ from sglang.srt.managers.io_struct import (
34
34
  TokenizedGenerateReqInput,
35
35
  WatchLoadUpdateReq,
36
36
  )
37
- from sglang.srt.managers.schedule_batch import Req
37
+ from sglang.srt.managers.schedule_batch import Req, RequestStage
38
38
  from sglang.srt.managers.scheduler import run_scheduler_process
39
39
  from sglang.srt.server_args import (
40
40
  DP_ATTENTION_HANDSHAKE_PORT_DELTA,
41
41
  PortArgs,
42
42
  ServerArgs,
43
43
  )
44
+ from sglang.srt.tracing.trace import (
45
+ process_tracing_init,
46
+ trace_get_proc_propagate_context,
47
+ trace_set_proc_propagate_context,
48
+ trace_set_thread_info,
49
+ trace_slice_end,
50
+ trace_slice_start,
51
+ )
44
52
  from sglang.srt.utils import (
45
53
  bind_port,
46
54
  configure_logger,
@@ -170,11 +178,22 @@ class DataParallelController:
170
178
  def handle_load_update_req(self, obj):
171
179
  self.dp_budget.update_budget(obj)
172
180
 
181
+ def dispatching_with_trace(self, req: Req):
182
+ if self.server_args.enable_trace:
183
+ trace_set_proc_propagate_context(req.rid, req.trace_context)
184
+ trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
185
+ req.trace_context = trace_get_proc_propagate_context(req.rid)
186
+
187
+ self.dispatching(req)
188
+
189
+ if self.server_args.enable_trace:
190
+ trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)
191
+
173
192
  def init_dispatcher(self):
174
193
  self._request_dispatcher = TypeBasedDispatcher(
175
194
  [
176
- (TokenizedGenerateReqInput, self.dispatching),
177
- (TokenizedEmbeddingReqInput, self.dispatching),
195
+ (TokenizedGenerateReqInput, self.dispatching_with_trace),
196
+ (TokenizedEmbeddingReqInput, self.dispatching_with_trace),
178
197
  (BlockReqInput, self.send_to_all_workers),
179
198
  (WatchLoadUpdateReq, self.handle_load_update_req),
180
199
  ]
@@ -487,6 +506,14 @@ def run_data_parallel_controller_process(
487
506
  pipe_writer,
488
507
  ):
489
508
  kill_itself_when_parent_died()
509
+ if server_args.enable_trace:
510
+ process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
511
+ thread_label = "DP Controller"
512
+ if server_args.disaggregation_mode == "prefill":
513
+ thread_label = "Prefill DP Controller"
514
+ elif server_args.disaggregation_mode == "decode":
515
+ thread_label = "Decode DP Controller"
516
+ trace_set_thread_info(thread_label)
490
517
  setproctitle.setproctitle("sglang::data_parallel_controller")
491
518
  faulthandler.enable()
492
519
  configure_logger(server_args)
@@ -235,6 +235,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
235
235
  new_text = ""
236
236
  else:
237
237
  new_text = find_printable_text(new_text)
238
+ else:
239
+ del self.decode_status[recv_obj.rids[i]]
238
240
 
239
241
  output_str = self.trim_matched_stop(
240
242
  s.decoded_text + new_text,
@@ -273,6 +275,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
273
275
  output_hidden_states=recv_obj.output_hidden_states,
274
276
  placeholder_tokens_idx=None,
275
277
  placeholder_tokens_val=None,
278
+ retraction_counts=recv_obj.retraction_counts,
276
279
  token_steps=recv_obj.token_steps,
277
280
  )
278
281
 
@@ -695,6 +695,9 @@ class EmbeddingReqInput(BaseReq):
695
695
  # tracing context
696
696
  trace_context: Optional[Dict] = None
697
697
 
698
+ # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
699
+ dimensions: Optional[int] = None
700
+
698
701
  def normalize_batch_and_arguments(self):
699
702
  # at least one of text, input_ids, or image should be provided
700
703
  if self.text is None and self.input_ids is None and self.image_data is None:
@@ -771,6 +774,7 @@ class EmbeddingReqInput(BaseReq):
771
774
  video_data=self.video_data[i] if self.video_data is not None else None,
772
775
  sampling_params=self.sampling_params[i],
773
776
  rid=self.rid[i],
777
+ dimensions=self.dimensions,
774
778
  http_worker_ipc=self.http_worker_ipc,
775
779
  )
776
780
 
@@ -791,6 +795,8 @@ class TokenizedEmbeddingReqInput(BaseReq):
791
795
  data_parallel_rank: Optional[int] = None
792
796
  # Priority for the request
793
797
  priority: Optional[int] = None
798
+ # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
799
+ dimensions: Optional[int] = None
794
800
 
795
801
 
796
802
  @dataclass
@@ -854,6 +860,9 @@ class BatchTokenIDOutput(BaseBatchReq):
854
860
  placeholder_tokens_idx: List[Optional[List[int]]]
855
861
  placeholder_tokens_val: List[Optional[List[int]]]
856
862
 
863
+ # Number of times each request was retracted.
864
+ retraction_counts: List[int]
865
+
857
866
  # The trainer step id. Used to know which step's weights are used for sampling.
858
867
  token_steps: List[List[int]] = None
859
868
 
@@ -930,6 +939,9 @@ class BatchStrOutput(BaseBatchReq):
930
939
  placeholder_tokens_idx: List[Optional[List[int]]]
931
940
  placeholder_tokens_val: List[Optional[List[int]]]
932
941
 
942
+ # Number of times each request was retracted.
943
+ retraction_counts: List[int]
944
+
933
945
  # The trainer step id. Used to know which step's weights are used for sampling.
934
946
  token_steps: List[List[int]] = None
935
947
 
@@ -972,6 +984,9 @@ class BatchEmbeddingOutput(BaseBatchReq):
972
984
  placeholder_tokens_idx: List[Optional[List[int]]]
973
985
  placeholder_tokens_val: List[Optional[List[int]]]
974
986
 
987
+ # Number of times each request was retracted.
988
+ retraction_counts: List[int]
989
+
975
990
 
976
991
  @dataclass
977
992
  class ClearHiCacheReqInput(BaseReq):
@@ -1215,7 +1230,7 @@ class AbortReq(BaseReq):
1215
1230
  abort_all: bool = False
1216
1231
  # The finished reason data
1217
1232
  finished_reason: Optional[Dict[str, Any]] = None
1218
- abort_reason: Optional[str] = None
1233
+ abort_message: Optional[str] = None
1219
1234
 
1220
1235
  def __post_init__(self):
1221
1236
  # FIXME: This is a hack to keep the same with the old code
@@ -1458,6 +1473,16 @@ class WatchLoadUpdateReq(BaseReq):
1458
1473
  loads: List[GetLoadReqOutput]
1459
1474
 
1460
1475
 
1476
+ @dataclass
1477
+ class SetInjectDumpMetadataReqInput(BaseReq):
1478
+ dump_metadata: Dict[str, Any]
1479
+
1480
+
1481
+ @dataclass
1482
+ class SetInjectDumpMetadataReqOutput(BaseReq):
1483
+ success: bool
1484
+
1485
+
1461
1486
  @dataclass
1462
1487
  class LazyDumpTensorsReqInput(BaseReq):
1463
1488
  pass
@@ -1489,6 +1514,3 @@ def _check_all_req_types():
1489
1514
  raise ValueError(
1490
1515
  f"{name} is a subclass of BaseReq but not follow the naming convention."
1491
1516
  )
1492
-
1493
-
1494
- _check_all_req_types()
@@ -334,6 +334,11 @@ def _handle_output_by_index(output, i):
334
334
  ),
335
335
  placeholder_tokens_idx=None,
336
336
  placeholder_tokens_val=None,
337
+ retraction_counts=(
338
+ [output.retraction_counts[i]]
339
+ if len(output.retraction_counts) > i
340
+ else None
341
+ ),
337
342
  token_steps=([output.token_steps[i]] if output.token_steps else None),
338
343
  )
339
344
  elif isinstance(output, BatchMultimodalOutput):
@@ -2,6 +2,8 @@ from __future__ import annotations
2
2
 
3
3
  import enum
4
4
 
5
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
6
+
5
7
  # Copyright 2023-2024 SGLang Team
6
8
  # Licensed under the Apache License, Version 2.0 (the "License");
7
9
  # you may not use this file except in compliance with the License.
@@ -70,11 +72,18 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
70
72
  from sglang.srt.mem_cache.radix_cache import RadixKey
71
73
  from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
72
74
  from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
73
- from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
75
+ from sglang.srt.model_executor.forward_batch_info import (
76
+ CaptureHiddenMode,
77
+ ForwardBatch,
78
+ ForwardMode,
79
+ )
74
80
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
75
81
  from sglang.srt.sampling.sampling_params import SamplingParams
76
82
  from sglang.srt.server_args import ServerArgs, get_global_server_args
77
83
  from sglang.srt.utils import flatten_nested_list
84
+ from sglang.srt.utils.common import is_npu
85
+
86
+ _is_npu = is_npu()
78
87
 
79
88
  if TYPE_CHECKING:
80
89
  from sglang.srt.configs.model_config import ModelConfig
@@ -392,13 +401,23 @@ class MultimodalInputs:
392
401
 
393
402
 
394
403
  class RequestStage(str, enum.Enum):
395
- # prefill
404
+ # Tokenizer
405
+ TOKENIZE = "tokenize"
406
+ TOKENIZER_DISPATCH = "dispatch"
407
+
408
+ # DP controller
409
+ DC_DISPATCH = "dc_dispatch"
410
+
411
+ # common/non-disaggregation
396
412
  PREFILL_WAITING = "prefill_waiting"
413
+ REQUEST_PROCESS = "request_process"
414
+ DECODE_LOOP = "decode_loop"
415
+ PREFILL_FORWARD = "prefill_forward"
416
+ PREFILL_CHUNKED_FORWARD = "chunked_prefill"
397
417
 
398
418
  # disaggregation prefill
399
419
  PREFILL_PREPARE = "prefill_prepare"
400
420
  PREFILL_BOOTSTRAP = "prefill_bootstrap"
401
- PREFILL_FORWARD = "prefill_forward"
402
421
  PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
403
422
 
404
423
  # disaggregation decode
@@ -406,6 +425,8 @@ class RequestStage(str, enum.Enum):
406
425
  DECODE_BOOTSTRAP = "decode_bootstrap"
407
426
  DECODE_WAITING = "decode_waiting"
408
427
  DECODE_TRANSFERRED = "decode_transferred"
428
+ DECODE_FAKE_OUTPUT = "fake_output"
429
+ DECODE_QUICK_FINISH = "quick_finish"
409
430
 
410
431
 
411
432
  class Req:
@@ -438,6 +459,7 @@ class Req:
438
459
  priority: Optional[int] = None,
439
460
  metrics_collector: Optional[SchedulerMetricsCollector] = None,
440
461
  extra_key: Optional[str] = None,
462
+ dimensions: Optional[int] = None,
441
463
  http_worker_ipc: Optional[str] = None,
442
464
  ):
443
465
  # Input and output info
@@ -490,16 +512,15 @@ class Req:
490
512
 
491
513
  # Check finish
492
514
  self.tokenizer = None
493
- self.finished_reason = None
515
+ self.finished_reason: Optional[BaseFinishReason] = None
494
516
  # finished position (in output_ids), used when checking stop conditions with speculative decoding
495
517
  self.finished_len = None
496
518
  # Whether this request has finished output
497
519
  self.finished_output = None
498
- # If we want to abort the request in the middle of the event loop, set this to true
520
+ # If we want to abort the request in the middle of the event loop,
521
+ # set to_finish instead of directly setting finished_reason.
499
522
  # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
500
- self.to_abort = False
501
- # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
502
- self.to_abort_message: str = None
523
+ self.to_finish: Optional[BaseFinishReason] = None
503
524
  self.stream = stream
504
525
  self.eos_token_ids = eos_token_ids
505
526
  self.vocab_size = vocab_size
@@ -618,6 +639,9 @@ class Req:
618
639
  # This is used to compute the acceptance rate and average acceptance length per request.
619
640
  self.spec_accepted_tokens = 0
620
641
 
642
+ # The number of times this request has been retracted / preempted.
643
+ self.retraction_count = 0
644
+
621
645
  # For metrics
622
646
  self.metrics_collector = metrics_collector
623
647
  self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
@@ -646,6 +670,9 @@ class Req:
646
670
  self.tmp_end_idx: int = -1
647
671
  self.metadata_buffer_index: int = -1
648
672
 
673
+ # For Matryoshka embeddings
674
+ self.dimensions = dimensions
675
+
649
676
  @property
650
677
  def seqlen(self):
651
678
  return len(self.origin_input_ids) + len(self.output_ids)
@@ -845,10 +872,9 @@ class Req:
845
872
  if self.finished():
846
873
  return
847
874
 
848
- if self.to_abort:
849
- self.finished_reason = FINISH_ABORT(
850
- message=self.to_abort_message,
851
- )
875
+ if self.to_finish:
876
+ self.finished_reason = self.to_finish
877
+ self.to_finish = None
852
878
  return
853
879
 
854
880
  if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -875,6 +901,10 @@ class Req:
875
901
  return
876
902
 
877
903
  def reset_for_retract(self):
904
+ # Increment retraction count before resetting other state. We should not reset this
905
+ # since we are tracking the total number of retractions for each request.
906
+ self.retraction_count += 1
907
+
878
908
  self.prefix_indices = torch.empty((0,), dtype=torch.int64)
879
909
  self.last_node = None
880
910
  self.swa_uuid_for_lock = None
@@ -920,7 +950,7 @@ class Req:
920
950
  self.grammar = None
921
951
  self.origin_input_ids = [0] # set it to one token to skip the long prefill
922
952
  self.return_logprob = False
923
- self.finished_reason = FINISH_ABORT(
953
+ self.to_finish = FINISH_ABORT(
924
954
  error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
925
955
  )
926
956
 
@@ -1010,6 +1040,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1010
1040
  encoder_lens_cpu: Optional[List[int]] = None
1011
1041
  encoder_out_cache_loc: Optional[torch.Tensor] = None
1012
1042
 
1043
+ # For matryoshka embeddings
1044
+ dimensions: Optional[list[int]] = None
1045
+
1046
+ # For split prefill
1047
+ split_index: int = 0
1048
+ split_prefill_finished: bool = False
1049
+ split_forward_count: int = 1
1050
+ split_forward_batch: ForwardBatch = None
1051
+ seq_lens_cpu_cache: torch.Tensor = None
1052
+
1013
1053
  # Stream
1014
1054
  has_stream: bool = False
1015
1055
 
@@ -1017,7 +1057,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1017
1057
  has_grammar: bool = False
1018
1058
 
1019
1059
  # Device
1020
- device: str = "cuda"
1060
+ if not _is_npu:
1061
+ device: str = "cuda"
1062
+ else:
1063
+ device: str = "npu"
1021
1064
 
1022
1065
  # Speculative decoding
1023
1066
  spec_algorithm: SpeculativeAlgorithm = None
@@ -1166,6 +1209,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1166
1209
  prefix_lens = [len(r.prefix_indices) for r in reqs]
1167
1210
  extend_lens = [r.extend_input_len for r in reqs]
1168
1211
 
1212
+ # For matryoshka embeddings
1213
+ if self.model_config.is_matryoshka and any(
1214
+ r.dimensions is not None for r in reqs
1215
+ ):
1216
+ self.dimensions = [
1217
+ r.dimensions if r.dimensions else self.model_config.hidden_size
1218
+ for r in reqs
1219
+ ]
1220
+
1169
1221
  token_type_ids = [
1170
1222
  r.token_type_ids for r in reqs if r.token_type_ids is not None
1171
1223
  ]
@@ -1367,6 +1419,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1367
1419
  self.extend_num_tokens += running_bs
1368
1420
  # TODO (lianmin): Revisit this. It should be seq_len - 1
1369
1421
  self.extend_logprob_start_lens.extend([0] * running_bs)
1422
+ self.is_prefill_only = False
1370
1423
 
1371
1424
  def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
1372
1425
  page_size = self.token_to_kv_pool_allocator.page_size
@@ -1397,7 +1450,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1397
1450
  evict_from_tree_cache(self.tree_cache, num_tokens)
1398
1451
  return self._is_available_size_sufficient(num_tokens)
1399
1452
 
1400
- def retract_decode(self, server_args: ServerArgs):
1453
+ def retract_decode(
1454
+ self, server_args: ServerArgs
1455
+ ) -> Tuple[List[Req], float, List[Req]]:
1401
1456
  """Retract the decoding requests when there is not enough memory."""
1402
1457
  sorted_indices = list(range(len(self.reqs)))
1403
1458
 
@@ -1754,6 +1809,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
1754
1809
  ),
1755
1810
  extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
1756
1811
  is_prefill_only=self.is_prefill_only,
1812
+ dimensions=self.dimensions,
1757
1813
  )
1758
1814
 
1759
1815
  def copy(self):
@@ -1862,5 +1918,8 @@ class ModelWorkerBatch:
1862
1918
  capture_hidden_mode: CaptureHiddenMode = None
1863
1919
  hicache_consumer_index: int = -1
1864
1920
 
1921
+ # For matryoshka embeddings
1922
+ dimensions: Optional[list[int]] = None
1923
+
1865
1924
  # Whether this batch is prefill-only (no token generation needed)
1866
1925
  is_prefill_only: bool = False