sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -2,6 +2,8 @@ from __future__ import annotations
 
 import enum
 
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -70,11 +72,18 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
-from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs, get_global_server_args
 from sglang.srt.utils import flatten_nested_list
+from sglang.srt.utils.common import is_npu
+
+_is_npu = is_npu()
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
@@ -392,13 +401,23 @@ class MultimodalInputs:
 
 
 class RequestStage(str, enum.Enum):
-    # prefill
+    # Tokenizer
+    TOKENIZE = "tokenize"
+    TOKENIZER_DISPATCH = "dispatch"
+
+    # DP controller
+    DC_DISPATCH = "dc_dispatch"
+
+    # common/non-disaggregation
     PREFILL_WAITING = "prefill_waiting"
+    REQUEST_PROCESS = "request_process"
+    DECODE_LOOP = "decode_loop"
+    PREFILL_FORWARD = "prefill_forward"
+    PREFILL_CHUNKED_FORWARD = "chunked_prefill"
 
     # disaggregation prefill
     PREFILL_PREPARE = "prefill_prepare"
     PREFILL_BOOTSTRAP = "prefill_bootstrap"
-    PREFILL_FORWARD = "prefill_forward"
     PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
 
     # disaggregation decode
@@ -406,6 +425,8 @@ class RequestStage(str, enum.Enum):
     DECODE_BOOTSTRAP = "decode_bootstrap"
     DECODE_WAITING = "decode_waiting"
     DECODE_TRANSFERRED = "decode_transferred"
+    DECODE_FAKE_OUTPUT = "fake_output"
+    DECODE_QUICK_FINISH = "quick_finish"
 
 
 class Req:
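The expanded RequestStage enum now spans the whole request lifecycle, from tokenization and DP-controller dispatch through prefill and the decode loop. A minimal sketch of how per-stage wall-clock accounting could be layered on these values; the StageTimer helper is hypothetical and only illustrates the idea, it is not sglang's tracing API:

import enum
import time
from collections import defaultdict


class RequestStage(str, enum.Enum):
    # Values copied from the diff above (non-disaggregation subset).
    TOKENIZE = "tokenize"
    TOKENIZER_DISPATCH = "dispatch"
    DC_DISPATCH = "dc_dispatch"
    PREFILL_WAITING = "prefill_waiting"
    REQUEST_PROCESS = "request_process"
    PREFILL_FORWARD = "prefill_forward"
    PREFILL_CHUNKED_FORWARD = "chunked_prefill"
    DECODE_LOOP = "decode_loop"


class StageTimer:
    """Hypothetical helper: accumulate wall-clock time spent per stage."""

    def __init__(self) -> None:
        self.elapsed = defaultdict(float)

    def record(self, stage: RequestStage, start: float) -> None:
        self.elapsed[stage.value] += time.perf_counter() - start


timer = StageTimer()
t0 = time.perf_counter()
# ... tokenize the request here ...
timer.record(RequestStage.TOKENIZE, t0)
print(dict(timer.elapsed))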
@@ -438,6 +459,7 @@ class Req:
         priority: Optional[int] = None,
         metrics_collector: Optional[SchedulerMetricsCollector] = None,
         extra_key: Optional[str] = None,
+        dimensions: Optional[int] = None,
         http_worker_ipc: Optional[str] = None,
     ):
         # Input and output info
@@ -490,16 +512,15 @@ class Req:
 
         # Check finish
         self.tokenizer = None
-        self.finished_reason = None
+        self.finished_reason: Optional[BaseFinishReason] = None
         # finished position (in output_ids), used when checking stop conditions with speculative decoding
         self.finished_len = None
         # Whether this request has finished output
        self.finished_output = None
-        # If we want to abort the request in the middle of the event loop, set this to true
+        # If we want to abort the request in the middle of the event loop,
+        # set to_finish instead of directly setting finished_reason.
         # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
-        self.to_abort = False
-        # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
-        self.to_abort_message: str = None
+        self.to_finish: Optional[BaseFinishReason] = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
         self.vocab_size = vocab_size
@@ -618,6 +639,9 @@ class Req:
         # This is used to compute the acceptance rate and average acceptance length per request.
         self.spec_accepted_tokens = 0
 
+        # The number of times this request has been retracted / preempted.
+        self.retraction_count = 0
+
         # For metrics
         self.metrics_collector = metrics_collector
         self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
@@ -646,6 +670,9 @@ class Req:
         self.tmp_end_idx: int = -1
         self.metadata_buffer_index: int = -1
 
+        # For Matryoshka embeddings
+        self.dimensions = dimensions
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -845,10 +872,9 @@ class Req:
         if self.finished():
             return
 
-        if self.to_abort:
-            self.finished_reason = FINISH_ABORT(
-                message=self.to_abort_message,
-            )
+        if self.to_finish:
+            self.finished_reason = self.to_finish
+            self.to_finish = None
             return
 
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
@@ -875,6 +901,10 @@ class Req:
             return
 
     def reset_for_retract(self):
+        # Increment retraction count before resetting other state. We should not reset this
+        # since we are tracking the total number of retractions for each request.
+        self.retraction_count += 1
+
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         self.last_node = None
         self.swa_uuid_for_lock = None
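The new retraction_count is bumped at the top of reset_for_retract() and never cleared, so it records how many times a request has been preempted over its lifetime. A minimal, self-contained sketch of that behaviour with a simplified stand-in class (not the real Req):

class ReqSketch:
    def __init__(self) -> None:
        self.retraction_count = 0
        self.prefix_indices: list[int] = []
        self.output_ids: list[int] = [1, 2, 3]

    def reset_for_retract(self) -> None:
        self.retraction_count += 1   # kept across retractions
        self.prefix_indices = []     # per-attempt state is cleared
        self.output_ids = []


req = ReqSketch()
req.reset_for_retract()
req.reset_for_retract()
assert req.retraction_count == 2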
@@ -920,7 +950,7 @@ class Req:
         self.grammar = None
         self.origin_input_ids = [0]  # set it to one token to skip the long prefill
         self.return_logprob = False
-        self.finished_reason = FINISH_ABORT(
+        self.to_finish = FINISH_ABORT(
             error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         )
 
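The abort path now routes through a single to_finish field: code that wants to stop a request mid-loop stashes a finish reason there, and check_finished() promotes it to finished_reason at a safe point, so the request is never filtered out before it can respond. A minimal, self-contained sketch of that pattern, with simplified stand-ins for Req and FINISH_ABORT (names mirror the diff, the classes are not sglang's):

from http import HTTPStatus
from typing import Optional


class FinishReason:  # simplified stand-in for BaseFinishReason / FINISH_ABORT
    def __init__(self, message: str, status: HTTPStatus, err_type: str) -> None:
        self.message, self.status, self.err_type = message, status, err_type


class ReqSketch:
    def __init__(self) -> None:
        self.finished_reason: Optional[FinishReason] = None
        # Deferred finish: set this mid-loop instead of finished_reason.
        self.to_finish: Optional[FinishReason] = None

    def set_finish_with_abort(self, error_msg: str) -> None:
        self.to_finish = FinishReason(error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError")

    def check_finished(self) -> None:
        # Called once per scheduler iteration, where finishing is safe.
        if self.to_finish is not None:
            self.finished_reason = self.to_finish
            self.to_finish = None


req = ReqSketch()
req.set_finish_with_abort("prompt too long")
assert req.finished_reason is None      # not finished mid-loop
req.check_finished()
assert req.finished_reason is not None  # promoted at the loop boundary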
@@ -1010,6 +1040,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     encoder_lens_cpu: Optional[List[int]] = None
     encoder_out_cache_loc: Optional[torch.Tensor] = None
 
+    # For matryoshka embeddings
+    dimensions: Optional[list[int]] = None
+
+    # For split prefill
+    split_index: int = 0
+    split_prefill_finished: bool = False
+    split_forward_count: int = 1
+    split_forward_batch: ForwardBatch = None
+    seq_lens_cpu_cache: torch.Tensor = None
+
     # Stream
     has_stream: bool = False
 
@@ -1017,7 +1057,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     has_grammar: bool = False
 
     # Device
-    device: str = "cuda"
+    if not _is_npu:
+        device: str = "cuda"
+    else:
+        device: str = "npu"
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
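Because ScheduleBatch is a dataclass, the if not _is_npu: block in the class body runs once at import time and simply decides which default the device field gets. A minimal sketch of that mechanism, assuming a stub is_npu() in place of sglang.srt.utils.common.is_npu so it runs anywhere:

from dataclasses import dataclass


def is_npu() -> bool:
    # Placeholder for sglang.srt.utils.common.is_npu; the real check detects
    # an Ascend NPU build. Hard-coded here so the sketch is self-contained.
    return False


_is_npu = is_npu()


@dataclass
class BatchSketch:
    # Class-body conditionals execute once when the class is defined, so the
    # dataclass default for `device` is fixed to the detected platform.
    if not _is_npu:
        device: str = "cuda"
    else:
        device: str = "npu"


print(BatchSketch().device)  # "cuda" here; "npu" when is_npu() returns True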
@@ -1166,6 +1209,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        # For matryoshka embeddings
+        if self.model_config.is_matryoshka and any(
+            r.dimensions is not None for r in reqs
+        ):
+            self.dimensions = [
+                r.dimensions if r.dimensions else self.model_config.hidden_size
+                for r in reqs
+            ]
+
         token_type_ids = [
             r.token_type_ids for r in reqs if r.token_type_ids is not None
         ]
@@ -1367,6 +1419,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.extend_num_tokens += running_bs
         # TODO (lianmin): Revisit this. It should be seq_len - 1
         self.extend_logprob_start_lens.extend([0] * running_bs)
+        self.is_prefill_only = False
 
     def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
         page_size = self.token_to_kv_pool_allocator.page_size
@@ -1397,7 +1450,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         evict_from_tree_cache(self.tree_cache, num_tokens)
         return self._is_available_size_sufficient(num_tokens)
 
-    def retract_decode(self, server_args: ServerArgs):
+    def retract_decode(
+        self, server_args: ServerArgs
+    ) -> Tuple[List[Req], float, List[Req]]:
         """Retract the decoding requests when there is not enough memory."""
         sorted_indices = list(range(len(self.reqs)))
 
@@ -1754,6 +1809,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
             is_prefill_only=self.is_prefill_only,
+            dimensions=self.dimensions,
         )
 
     def copy(self):
@@ -1862,5 +1918,8 @@ class ModelWorkerBatch:
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1
 
+    # For matryoshka embeddings
+    dimensions: Optional[list[int]] = None
+
     # Whether this batch is prefill-only (no token generation needed)
     is_prefill_only: bool = False
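The dimensions plumbing (per-request on Req, per-batch on ScheduleBatch and ModelWorkerBatch) supports Matryoshka embeddings, where a caller asks for only the first N components of the embedding vector. A minimal sketch of how such per-request dimensions are typically applied downstream, by truncating and re-normalizing the pooled embedding; this illustrates the general technique and is not sglang's pooler code:

import torch


def apply_matryoshka_dims(
    embeddings: torch.Tensor,   # [batch, hidden_size] pooled embeddings
    dimensions: list[int],      # per-request target sizes, as built in ScheduleBatch
) -> list[torch.Tensor]:
    """Truncate each embedding to its requested dimension and L2-normalize."""
    out = []
    for emb, dim in zip(embeddings, dimensions):
        truncated = emb[:dim]
        out.append(truncated / truncated.norm(p=2).clamp_min(1e-12))
    return out


batch = torch.randn(2, 1024)
# One request asks for 256 dims, the other falls back to the full hidden size,
# mirroring `r.dimensions if r.dimensions else self.model_config.hidden_size`.
reduced = apply_matryoshka_dims(batch, [256, 1024])
print([v.shape for v in reduced])  # [torch.Size([256]), torch.Size([1024])]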