sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
 - sglang/bench_serving.py +18 -3
 - sglang/compile_deep_gemm.py +13 -7
 - sglang/srt/batch_invariant_ops/__init__.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
 - sglang/srt/checkpoint_engine/__init__.py +9 -0
 - sglang/srt/checkpoint_engine/update.py +317 -0
 - sglang/srt/configs/__init__.py +2 -0
 - sglang/srt/configs/deepseek_ocr.py +542 -10
 - sglang/srt/configs/deepseekvl2.py +95 -194
 - sglang/srt/configs/kimi_linear.py +160 -0
 - sglang/srt/configs/mamba_utils.py +66 -0
 - sglang/srt/configs/model_config.py +25 -2
 - sglang/srt/constants.py +7 -0
 - sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
 - sglang/srt/disaggregation/decode.py +34 -6
 - sglang/srt/disaggregation/nixl/conn.py +2 -2
 - sglang/srt/disaggregation/prefill.py +25 -3
 - sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
 - sglang/srt/distributed/parallel_state.py +9 -5
 - sglang/srt/entrypoints/engine.py +13 -5
 - sglang/srt/entrypoints/http_server.py +22 -3
 - sglang/srt/entrypoints/openai/protocol.py +7 -1
 - sglang/srt/entrypoints/openai/serving_chat.py +42 -0
 - sglang/srt/entrypoints/openai/serving_completions.py +10 -0
 - sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
 - sglang/srt/environ.py +7 -0
 - sglang/srt/eplb/expert_distribution.py +34 -1
 - sglang/srt/eplb/expert_location.py +106 -36
 - sglang/srt/grpc/compile_proto.py +3 -0
 - sglang/srt/layers/attention/ascend_backend.py +233 -5
 - sglang/srt/layers/attention/attention_registry.py +3 -0
 - sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
 - sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
 - sglang/srt/layers/attention/fla/kda.py +1359 -0
 - sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
 - sglang/srt/layers/attention/flashattention_backend.py +7 -6
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
 - sglang/srt/layers/attention/flashmla_backend.py +1 -1
 - sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
 - sglang/srt/layers/attention/mamba/mamba.py +20 -11
 - sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
 - sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
 - sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
 - sglang/srt/layers/attention/nsa/transform_index.py +1 -1
 - sglang/srt/layers/attention/nsa_backend.py +157 -23
 - sglang/srt/layers/attention/triton_backend.py +4 -1
 - sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
 - sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
 - sglang/srt/layers/communicator.py +23 -1
 - sglang/srt/layers/layernorm.py +16 -2
 - sglang/srt/layers/logits_processor.py +4 -20
 - sglang/srt/layers/moe/ep_moe/layer.py +0 -18
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
 - sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
 - sglang/srt/layers/moe/topk.py +31 -6
 - sglang/srt/layers/pooler.py +21 -2
 - sglang/srt/layers/quantization/__init__.py +9 -78
 - sglang/srt/layers/quantization/auto_round.py +394 -0
 - sglang/srt/layers/quantization/fp8_kernel.py +1 -1
 - sglang/srt/layers/quantization/fp8_utils.py +2 -2
 - sglang/srt/layers/quantization/modelopt_quant.py +168 -11
 - sglang/srt/layers/rotary_embedding.py +117 -45
 - sglang/srt/lora/lora_registry.py +9 -0
 - sglang/srt/managers/async_mm_data_processor.py +122 -0
 - sglang/srt/managers/data_parallel_controller.py +30 -3
 - sglang/srt/managers/detokenizer_manager.py +3 -0
 - sglang/srt/managers/io_struct.py +26 -4
 - sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
 - sglang/srt/managers/schedule_batch.py +74 -15
 - sglang/srt/managers/scheduler.py +164 -129
 - sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
 - sglang/srt/managers/scheduler_pp_mixin.py +7 -2
 - sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
 - sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
 - sglang/srt/managers/session_controller.py +6 -5
 - sglang/srt/managers/tokenizer_manager.py +154 -59
 - sglang/srt/managers/tp_worker.py +24 -1
 - sglang/srt/mem_cache/base_prefix_cache.py +23 -4
 - sglang/srt/mem_cache/common.py +1 -0
 - sglang/srt/mem_cache/memory_pool.py +171 -57
 - sglang/srt/mem_cache/memory_pool_host.py +12 -5
 - sglang/srt/mem_cache/radix_cache.py +4 -0
 - sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
 - sglang/srt/metrics/collector.py +46 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +15 -3
 - sglang/srt/model_executor/forward_batch_info.py +11 -11
 - sglang/srt/model_executor/model_runner.py +76 -21
 - sglang/srt/model_executor/npu_graph_runner.py +7 -3
 - sglang/srt/model_loader/weight_utils.py +1 -1
 - sglang/srt/models/bailing_moe.py +9 -2
 - sglang/srt/models/deepseek_nextn.py +11 -2
 - sglang/srt/models/deepseek_v2.py +149 -34
 - sglang/srt/models/glm4.py +391 -77
 - sglang/srt/models/glm4v.py +196 -55
 - sglang/srt/models/glm4v_moe.py +0 -1
 - sglang/srt/models/gpt_oss.py +1 -10
 - sglang/srt/models/kimi_linear.py +678 -0
 - sglang/srt/models/llama4.py +1 -1
 - sglang/srt/models/llama_eagle3.py +11 -1
 - sglang/srt/models/longcat_flash.py +2 -2
 - sglang/srt/models/minimax_m2.py +1 -1
 - sglang/srt/models/qwen2.py +1 -1
 - sglang/srt/models/qwen2_moe.py +30 -15
 - sglang/srt/models/qwen3.py +1 -1
 - sglang/srt/models/qwen3_moe.py +16 -8
 - sglang/srt/models/qwen3_next.py +7 -0
 - sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
 - sglang/srt/multiplex/multiplexing_mixin.py +209 -0
 - sglang/srt/multiplex/pdmux_context.py +164 -0
 - sglang/srt/parser/conversation.py +7 -1
 - sglang/srt/sampling/custom_logit_processor.py +67 -1
 - sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
 - sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
 - sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
 - sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
 - sglang/srt/server_args.py +103 -22
 - sglang/srt/single_batch_overlap.py +4 -1
 - sglang/srt/speculative/draft_utils.py +16 -0
 - sglang/srt/speculative/eagle_info.py +42 -36
 - sglang/srt/speculative/eagle_info_v2.py +68 -25
 - sglang/srt/speculative/eagle_utils.py +261 -16
 - sglang/srt/speculative/eagle_worker.py +11 -3
 - sglang/srt/speculative/eagle_worker_v2.py +15 -9
 - sglang/srt/speculative/spec_info.py +305 -31
 - sglang/srt/speculative/spec_utils.py +44 -8
 - sglang/srt/tracing/trace.py +121 -12
 - sglang/srt/utils/common.py +55 -32
 - sglang/srt/utils/hf_transformers_utils.py +38 -16
 - sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
 - sglang/test/kits/radix_cache_server_kit.py +50 -0
 - sglang/test/runners.py +31 -7
 - sglang/test/simple_eval_common.py +5 -3
 - sglang/test/simple_eval_humaneval.py +1 -0
 - sglang/test/simple_eval_math.py +1 -0
 - sglang/test/simple_eval_mmlu.py +1 -0
 - sglang/test/simple_eval_mmmu_vlm.py +1 -0
 - sglang/test/test_utils.py +7 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
 - /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
 
| 
         @@ -34,13 +34,21 @@ from sglang.srt.managers.io_struct import ( 
     | 
|
| 
       34 
34 
     | 
    
         
             
                TokenizedGenerateReqInput,
         
     | 
| 
       35 
35 
     | 
    
         
             
                WatchLoadUpdateReq,
         
     | 
| 
       36 
36 
     | 
    
         
             
            )
         
     | 
| 
       37 
     | 
    
         
            -
            from sglang.srt.managers.schedule_batch import Req
         
     | 
| 
      
 37 
     | 
    
         
            +
            from sglang.srt.managers.schedule_batch import Req, RequestStage
         
     | 
| 
       38 
38 
     | 
    
         
             
            from sglang.srt.managers.scheduler import run_scheduler_process
         
     | 
| 
       39 
39 
     | 
    
         
             
            from sglang.srt.server_args import (
         
     | 
| 
       40 
40 
     | 
    
         
             
                DP_ATTENTION_HANDSHAKE_PORT_DELTA,
         
     | 
| 
       41 
41 
     | 
    
         
             
                PortArgs,
         
     | 
| 
       42 
42 
     | 
    
         
             
                ServerArgs,
         
     | 
| 
       43 
43 
     | 
    
         
             
            )
         
     | 
| 
      
 44 
     | 
    
         
            +
            from sglang.srt.tracing.trace import (
         
     | 
| 
      
 45 
     | 
    
         
            +
                process_tracing_init,
         
     | 
| 
      
 46 
     | 
    
         
            +
                trace_get_proc_propagate_context,
         
     | 
| 
      
 47 
     | 
    
         
            +
                trace_set_proc_propagate_context,
         
     | 
| 
      
 48 
     | 
    
         
            +
                trace_set_thread_info,
         
     | 
| 
      
 49 
     | 
    
         
            +
                trace_slice_end,
         
     | 
| 
      
 50 
     | 
    
         
            +
                trace_slice_start,
         
     | 
| 
      
 51 
     | 
    
         
            +
            )
         
     | 
| 
       44 
52 
     | 
    
         
             
            from sglang.srt.utils import (
         
     | 
| 
       45 
53 
     | 
    
         
             
                bind_port,
         
     | 
| 
       46 
54 
     | 
    
         
             
                configure_logger,
         
     | 
| 
         @@ -170,11 +178,22 @@ class DataParallelController: 
     | 
|
| 
       170 
178 
     | 
    
         
             
                def handle_load_update_req(self, obj):
         
     | 
| 
       171 
179 
     | 
    
         
             
                    self.dp_budget.update_budget(obj)
         
     | 
| 
       172 
180 
     | 
    
         | 
| 
      
 181 
     | 
    
         
            +
                def dispatching_with_trace(self, req: Req):
         
     | 
| 
      
 182 
     | 
    
         
            +
                    if self.server_args.enable_trace:
         
     | 
| 
      
 183 
     | 
    
         
            +
                        trace_set_proc_propagate_context(req.rid, req.trace_context)
         
     | 
| 
      
 184 
     | 
    
         
            +
                        trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
         
     | 
| 
      
 185 
     | 
    
         
            +
                        req.trace_context = trace_get_proc_propagate_context(req.rid)
         
     | 
| 
      
 186 
     | 
    
         
            +
             
     | 
| 
      
 187 
     | 
    
         
            +
                    self.dispatching(req)
         
     | 
| 
      
 188 
     | 
    
         
            +
             
     | 
| 
      
 189 
     | 
    
         
            +
                    if self.server_args.enable_trace:
         
     | 
| 
      
 190 
     | 
    
         
            +
                        trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)
         
     | 
| 
      
 191 
     | 
    
         
            +
             
     | 
| 
       173 
192 
     | 
    
         
             
                def init_dispatcher(self):
         
     | 
| 
       174 
193 
     | 
    
         
             
                    self._request_dispatcher = TypeBasedDispatcher(
         
     | 
| 
       175 
194 
     | 
    
         
             
                        [
         
     | 
| 
       176 
     | 
    
         
            -
                            (TokenizedGenerateReqInput, self. 
     | 
| 
       177 
     | 
    
         
            -
                            (TokenizedEmbeddingReqInput, self. 
     | 
| 
      
 195 
     | 
    
         
            +
                            (TokenizedGenerateReqInput, self.dispatching_with_trace),
         
     | 
| 
      
 196 
     | 
    
         
            +
                            (TokenizedEmbeddingReqInput, self.dispatching_with_trace),
         
     | 
| 
       178 
197 
     | 
    
         
             
                            (BlockReqInput, self.send_to_all_workers),
         
     | 
| 
       179 
198 
     | 
    
         
             
                            (WatchLoadUpdateReq, self.handle_load_update_req),
         
     | 
| 
       180 
199 
     | 
    
         
             
                        ]
         
     | 
| 
         @@ -487,6 +506,14 @@ def run_data_parallel_controller_process( 
     | 
|
| 
       487 
506 
     | 
    
         
             
                pipe_writer,
         
     | 
| 
       488 
507 
     | 
    
         
             
            ):
         
     | 
| 
       489 
508 
     | 
    
         
             
                kill_itself_when_parent_died()
         
     | 
| 
      
 509 
     | 
    
         
            +
                if server_args.enable_trace:
         
     | 
| 
      
 510 
     | 
    
         
            +
                    process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
         
     | 
| 
      
 511 
     | 
    
         
            +
                    thread_label = "DP Controller"
         
     | 
| 
      
 512 
     | 
    
         
            +
                    if server_args.disaggregation_mode == "prefill":
         
     | 
| 
      
 513 
     | 
    
         
            +
                        thread_label = "Prefill DP Controller"
         
     | 
| 
      
 514 
     | 
    
         
            +
                    elif server_args.disaggregation_mode == "decode":
         
     | 
| 
      
 515 
     | 
    
         
            +
                        thread_label = "Decode DP Controller"
         
     | 
| 
      
 516 
     | 
    
         
            +
                    trace_set_thread_info(thread_label)
         
     | 
| 
       490 
517 
     | 
    
         
             
                setproctitle.setproctitle("sglang::data_parallel_controller")
         
     | 
| 
       491 
518 
     | 
    
         
             
                faulthandler.enable()
         
     | 
| 
       492 
519 
     | 
    
         
             
                configure_logger(server_args)
         
     | 
| 
         @@ -235,6 +235,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): 
     | 
|
| 
       235 
235 
     | 
    
         
             
                                new_text = ""
         
     | 
| 
       236 
236 
     | 
    
         
             
                            else:
         
     | 
| 
       237 
237 
     | 
    
         
             
                                new_text = find_printable_text(new_text)
         
     | 
| 
      
 238 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 239 
     | 
    
         
            +
                            del self.decode_status[recv_obj.rids[i]]
         
     | 
| 
       238 
240 
     | 
    
         | 
| 
       239 
241 
     | 
    
         
             
                        output_str = self.trim_matched_stop(
         
     | 
| 
       240 
242 
     | 
    
         
             
                            s.decoded_text + new_text,
         
     | 
| 
         @@ -273,6 +275,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin): 
     | 
|
| 
       273 
275 
     | 
    
         
             
                        output_hidden_states=recv_obj.output_hidden_states,
         
     | 
| 
       274 
276 
     | 
    
         
             
                        placeholder_tokens_idx=None,
         
     | 
| 
       275 
277 
     | 
    
         
             
                        placeholder_tokens_val=None,
         
     | 
| 
      
 278 
     | 
    
         
            +
                        retraction_counts=recv_obj.retraction_counts,
         
     | 
| 
       276 
279 
     | 
    
         
             
                        token_steps=recv_obj.token_steps,
         
     | 
| 
       277 
280 
     | 
    
         
             
                    )
         
     | 
| 
       278 
281 
     | 
    
         | 
    
        sglang/srt/managers/io_struct.py
    CHANGED
    
    | 
         @@ -695,6 +695,9 @@ class EmbeddingReqInput(BaseReq): 
     | 
|
| 
       695 
695 
     | 
    
         
             
                # tracing context
         
     | 
| 
       696 
696 
     | 
    
         
             
                trace_context: Optional[Dict] = None
         
     | 
| 
       697 
697 
     | 
    
         | 
| 
      
 698 
     | 
    
         
            +
                # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
         
     | 
| 
      
 699 
     | 
    
         
            +
                dimensions: Optional[int] = None
         
     | 
| 
      
 700 
     | 
    
         
            +
             
     | 
| 
       698 
701 
     | 
    
         
             
                def normalize_batch_and_arguments(self):
         
     | 
| 
       699 
702 
     | 
    
         
             
                    # at least one of text, input_ids, or image should be provided
         
     | 
| 
       700 
703 
     | 
    
         
             
                    if self.text is None and self.input_ids is None and self.image_data is None:
         
     | 
| 
         @@ -771,6 +774,7 @@ class EmbeddingReqInput(BaseReq): 
     | 
|
| 
       771 
774 
     | 
    
         
             
                        video_data=self.video_data[i] if self.video_data is not None else None,
         
     | 
| 
       772 
775 
     | 
    
         
             
                        sampling_params=self.sampling_params[i],
         
     | 
| 
       773 
776 
     | 
    
         
             
                        rid=self.rid[i],
         
     | 
| 
      
 777 
     | 
    
         
            +
                        dimensions=self.dimensions,
         
     | 
| 
       774 
778 
     | 
    
         
             
                        http_worker_ipc=self.http_worker_ipc,
         
     | 
| 
       775 
779 
     | 
    
         
             
                    )
         
     | 
| 
       776 
780 
     | 
    
         | 
| 
         @@ -791,6 +795,8 @@ class TokenizedEmbeddingReqInput(BaseReq): 
     | 
|
| 
       791 
795 
     | 
    
         
             
                data_parallel_rank: Optional[int] = None
         
     | 
| 
       792 
796 
     | 
    
         
             
                # Priority for the request
         
     | 
| 
       793 
797 
     | 
    
         
             
                priority: Optional[int] = None
         
     | 
| 
      
 798 
     | 
    
         
            +
                # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
         
     | 
| 
      
 799 
     | 
    
         
            +
                dimensions: Optional[int] = None
         
     | 
| 
       794 
800 
     | 
    
         | 
| 
       795 
801 
     | 
    
         | 
| 
       796 
802 
     | 
    
         
             
            @dataclass
         
     | 
| 
         @@ -854,6 +860,9 @@ class BatchTokenIDOutput(BaseBatchReq): 
     | 
|
| 
       854 
860 
     | 
    
         
             
                placeholder_tokens_idx: List[Optional[List[int]]]
         
     | 
| 
       855 
861 
     | 
    
         
             
                placeholder_tokens_val: List[Optional[List[int]]]
         
     | 
| 
       856 
862 
     | 
    
         | 
| 
      
 863 
     | 
    
         
            +
                # Number of times each request was retracted.
         
     | 
| 
      
 864 
     | 
    
         
            +
                retraction_counts: List[int]
         
     | 
| 
      
 865 
     | 
    
         
            +
             
     | 
| 
       857 
866 
     | 
    
         
             
                # The trainer step id. Used to know which step's weights are used for sampling.
         
     | 
| 
       858 
867 
     | 
    
         
             
                token_steps: List[List[int]] = None
         
     | 
| 
       859 
868 
     | 
    
         | 
| 
         @@ -930,6 +939,9 @@ class BatchStrOutput(BaseBatchReq): 
     | 
|
| 
       930 
939 
     | 
    
         
             
                placeholder_tokens_idx: List[Optional[List[int]]]
         
     | 
| 
       931 
940 
     | 
    
         
             
                placeholder_tokens_val: List[Optional[List[int]]]
         
     | 
| 
       932 
941 
     | 
    
         | 
| 
      
 942 
     | 
    
         
            +
                # Number of times each request was retracted.
         
     | 
| 
      
 943 
     | 
    
         
            +
                retraction_counts: List[int]
         
     | 
| 
      
 944 
     | 
    
         
            +
             
     | 
| 
       933 
945 
     | 
    
         
             
                # The trainer step id. Used to know which step's weights are used for sampling.
         
     | 
| 
       934 
946 
     | 
    
         
             
                token_steps: List[List[int]] = None
         
     | 
| 
       935 
947 
     | 
    
         | 
| 
         @@ -972,6 +984,9 @@ class BatchEmbeddingOutput(BaseBatchReq): 
     | 
|
| 
       972 
984 
     | 
    
         
             
                placeholder_tokens_idx: List[Optional[List[int]]]
         
     | 
| 
       973 
985 
     | 
    
         
             
                placeholder_tokens_val: List[Optional[List[int]]]
         
     | 
| 
       974 
986 
     | 
    
         | 
| 
      
 987 
     | 
    
         
            +
                # Number of times each request was retracted.
         
     | 
| 
      
 988 
     | 
    
         
            +
                retraction_counts: List[int]
         
     | 
| 
      
 989 
     | 
    
         
            +
             
     | 
| 
       975 
990 
     | 
    
         | 
| 
       976 
991 
     | 
    
         
             
            @dataclass
         
     | 
| 
       977 
992 
     | 
    
         
             
            class ClearHiCacheReqInput(BaseReq):
         
     | 
| 
         @@ -1215,7 +1230,7 @@ class AbortReq(BaseReq): 
     | 
|
| 
       1215 
1230 
     | 
    
         
             
                abort_all: bool = False
         
     | 
| 
       1216 
1231 
     | 
    
         
             
                # The finished reason data
         
     | 
| 
       1217 
1232 
     | 
    
         
             
                finished_reason: Optional[Dict[str, Any]] = None
         
     | 
| 
       1218 
     | 
    
         
            -
                 
     | 
| 
      
 1233 
     | 
    
         
            +
                abort_message: Optional[str] = None
         
     | 
| 
       1219 
1234 
     | 
    
         | 
| 
       1220 
1235 
     | 
    
         
             
                def __post_init__(self):
         
     | 
| 
       1221 
1236 
     | 
    
         
             
                    # FIXME: This is a hack to keep the same with the old code
         
     | 
| 
         @@ -1458,6 +1473,16 @@ class WatchLoadUpdateReq(BaseReq): 
     | 
|
| 
       1458 
1473 
     | 
    
         
             
                loads: List[GetLoadReqOutput]
         
     | 
| 
       1459 
1474 
     | 
    
         | 
| 
       1460 
1475 
     | 
    
         | 
| 
      
 1476 
     | 
    
         
            +
            @dataclass
         
     | 
| 
      
 1477 
     | 
    
         
            +
            class SetInjectDumpMetadataReqInput(BaseReq):
         
     | 
| 
      
 1478 
     | 
    
         
            +
                dump_metadata: Dict[str, Any]
         
     | 
| 
      
 1479 
     | 
    
         
            +
             
     | 
| 
      
 1480 
     | 
    
         
            +
             
     | 
| 
      
 1481 
     | 
    
         
            +
            @dataclass
         
     | 
| 
      
 1482 
     | 
    
         
            +
            class SetInjectDumpMetadataReqOutput(BaseReq):
         
     | 
| 
      
 1483 
     | 
    
         
            +
                success: bool
         
     | 
| 
      
 1484 
     | 
    
         
            +
             
     | 
| 
      
 1485 
     | 
    
         
            +
             
     | 
| 
       1461 
1486 
     | 
    
         
             
            @dataclass
         
     | 
| 
       1462 
1487 
     | 
    
         
             
            class LazyDumpTensorsReqInput(BaseReq):
         
     | 
| 
       1463 
1488 
     | 
    
         
             
                pass
         
     | 
| 
         @@ -1489,6 +1514,3 @@ def _check_all_req_types(): 
     | 
|
| 
       1489 
1514 
     | 
    
         
             
                        raise ValueError(
         
     | 
| 
       1490 
1515 
     | 
    
         
             
                            f"{name} is a subclass of BaseReq but not follow the naming convention."
         
     | 
| 
       1491 
1516 
     | 
    
         
             
                        )
         
     | 
| 
       1492 
     | 
    
         
            -
             
     | 
| 
       1493 
     | 
    
         
            -
             
     | 
| 
       1494 
     | 
    
         
            -
            _check_all_req_types()
         
     | 
| 
         @@ -334,6 +334,11 @@ def _handle_output_by_index(output, i): 
     | 
|
| 
       334 
334 
     | 
    
         
             
                        ),
         
     | 
| 
       335 
335 
     | 
    
         
             
                        placeholder_tokens_idx=None,
         
     | 
| 
       336 
336 
     | 
    
         
             
                        placeholder_tokens_val=None,
         
     | 
| 
      
 337 
     | 
    
         
            +
                        retraction_counts=(
         
     | 
| 
      
 338 
     | 
    
         
            +
                            [output.retraction_counts[i]]
         
     | 
| 
      
 339 
     | 
    
         
            +
                            if len(output.retraction_counts) > i
         
     | 
| 
      
 340 
     | 
    
         
            +
                            else None
         
     | 
| 
      
 341 
     | 
    
         
            +
                        ),
         
     | 
| 
       337 
342 
     | 
    
         
             
                        token_steps=([output.token_steps[i]] if output.token_steps else None),
         
     | 
| 
       338 
343 
     | 
    
         
             
                    )
         
     | 
| 
       339 
344 
     | 
    
         
             
                elif isinstance(output, BatchMultimodalOutput):
         
     | 
| 
         @@ -2,6 +2,8 @@ from __future__ import annotations 
     | 
|
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            import enum
         
     | 
| 
       4 
4 
     | 
    
         | 
| 
      
 5 
     | 
    
         
            +
            from sglang.srt.model_executor.forward_batch_info import ForwardBatch
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
       5 
7 
     | 
    
         
             
            # Copyright 2023-2024 SGLang Team
         
     | 
| 
       6 
8 
     | 
    
         
             
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       7 
9 
     | 
    
         
             
            # you may not use this file except in compliance with the License.
         
     | 
| 
         @@ -70,11 +72,18 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool 
     | 
|
| 
       70 
72 
     | 
    
         
             
            from sglang.srt.mem_cache.radix_cache import RadixKey
         
     | 
| 
       71 
73 
     | 
    
         
             
            from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
         
     | 
| 
       72 
74 
     | 
    
         
             
            from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
         
     | 
| 
       73 
     | 
    
         
            -
            from sglang.srt.model_executor.forward_batch_info import  
     | 
| 
      
 75 
     | 
    
         
            +
            from sglang.srt.model_executor.forward_batch_info import (
         
     | 
| 
      
 76 
     | 
    
         
            +
                CaptureHiddenMode,
         
     | 
| 
      
 77 
     | 
    
         
            +
                ForwardBatch,
         
     | 
| 
      
 78 
     | 
    
         
            +
                ForwardMode,
         
     | 
| 
      
 79 
     | 
    
         
            +
            )
         
     | 
| 
       74 
80 
     | 
    
         
             
            from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
         
     | 
| 
       75 
81 
     | 
    
         
             
            from sglang.srt.sampling.sampling_params import SamplingParams
         
     | 
| 
       76 
82 
     | 
    
         
             
            from sglang.srt.server_args import ServerArgs, get_global_server_args
         
     | 
| 
       77 
83 
     | 
    
         
             
            from sglang.srt.utils import flatten_nested_list
         
     | 
| 
      
 84 
     | 
    
         
            +
            from sglang.srt.utils.common import is_npu
         
     | 
| 
      
 85 
     | 
    
         
            +
             
     | 
| 
      
 86 
     | 
    
         
            +
            _is_npu = is_npu()
         
     | 
| 
       78 
87 
     | 
    
         | 
| 
       79 
88 
     | 
    
         
             
            if TYPE_CHECKING:
         
     | 
| 
       80 
89 
     | 
    
         
             
                from sglang.srt.configs.model_config import ModelConfig
         
     | 
| 
         @@ -392,13 +401,23 @@ class MultimodalInputs: 
     | 
|
| 
       392 
401 
     | 
    
         | 
| 
       393 
402 
     | 
    
         | 
| 
       394 
403 
     | 
    
         
             
            class RequestStage(str, enum.Enum):
         
     | 
| 
       395 
     | 
    
         
            -
                #  
     | 
| 
      
 404 
     | 
    
         
            +
                # Tokenizer
         
     | 
| 
      
 405 
     | 
    
         
            +
                TOKENIZE = "tokenize"
         
     | 
| 
      
 406 
     | 
    
         
            +
                TOKENIZER_DISPATCH = "dispatch"
         
     | 
| 
      
 407 
     | 
    
         
            +
             
     | 
| 
      
 408 
     | 
    
         
            +
                # DP controller
         
     | 
| 
      
 409 
     | 
    
         
            +
                DC_DISPATCH = "dc_dispatch"
         
     | 
| 
      
 410 
     | 
    
         
            +
             
     | 
| 
      
 411 
     | 
    
         
            +
                # common/non-disaggregation
         
     | 
| 
       396 
412 
     | 
    
         
             
                PREFILL_WAITING = "prefill_waiting"
         
     | 
| 
      
 413 
     | 
    
         
            +
                REQUEST_PROCESS = "request_process"
         
     | 
| 
      
 414 
     | 
    
         
            +
                DECODE_LOOP = "decode_loop"
         
     | 
| 
      
 415 
     | 
    
         
            +
                PREFILL_FORWARD = "prefill_forward"
         
     | 
| 
      
 416 
     | 
    
         
            +
                PREFILL_CHUNKED_FORWARD = "chunked_prefill"
         
     | 
| 
       397 
417 
     | 
    
         | 
| 
       398 
418 
     | 
    
         
             
                # disaggregation prefill
         
     | 
| 
       399 
419 
     | 
    
         
             
                PREFILL_PREPARE = "prefill_prepare"
         
     | 
| 
       400 
420 
     | 
    
         
             
                PREFILL_BOOTSTRAP = "prefill_bootstrap"
         
     | 
| 
       401 
     | 
    
         
            -
                PREFILL_FORWARD = "prefill_forward"
         
     | 
| 
       402 
421 
     | 
    
         
             
                PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
         
     | 
| 
       403 
422 
     | 
    
         | 
| 
       404 
423 
     | 
    
         
             
                # disaggregation decode
         
     | 
| 
         @@ -406,6 +425,8 @@ class RequestStage(str, enum.Enum): 
     | 
|
| 
       406 
425 
     | 
    
         
             
                DECODE_BOOTSTRAP = "decode_bootstrap"
         
     | 
| 
       407 
426 
     | 
    
         
             
                DECODE_WAITING = "decode_waiting"
         
     | 
| 
       408 
427 
     | 
    
         
             
                DECODE_TRANSFERRED = "decode_transferred"
         
     | 
| 
      
 428 
     | 
    
         
            +
                DECODE_FAKE_OUTPUT = "fake_output"
         
     | 
| 
      
 429 
     | 
    
         
            +
                DECODE_QUICK_FINISH = "quick_finish"
         
     | 
| 
       409 
430 
     | 
    
         | 
| 
       410 
431 
     | 
    
         | 
| 
       411 
432 
     | 
    
         
             
            class Req:
         
     | 
| 
         @@ -438,6 +459,7 @@ class Req: 
     | 
|
| 
       438 
459 
     | 
    
         
             
                    priority: Optional[int] = None,
         
     | 
| 
       439 
460 
     | 
    
         
             
                    metrics_collector: Optional[SchedulerMetricsCollector] = None,
         
     | 
| 
       440 
461 
     | 
    
         
             
                    extra_key: Optional[str] = None,
         
     | 
| 
      
 462 
     | 
    
         
            +
                    dimensions: Optional[int] = None,
         
     | 
| 
       441 
463 
     | 
    
         
             
                    http_worker_ipc: Optional[str] = None,
         
     | 
| 
       442 
464 
     | 
    
         
             
                ):
         
     | 
| 
       443 
465 
     | 
    
         
             
                    # Input and output info
         
     | 
| 
         @@ -490,16 +512,15 @@ class Req: 
     | 
|
| 
       490 
512 
     | 
    
         | 
| 
       491 
513 
     | 
    
         
             
                    # Check finish
         
     | 
| 
       492 
514 
     | 
    
         
             
                    self.tokenizer = None
         
     | 
| 
       493 
     | 
    
         
            -
                    self.finished_reason = None
         
     | 
| 
      
 515 
     | 
    
         
            +
                    self.finished_reason: Optional[BaseFinishReason] = None
         
     | 
| 
       494 
516 
     | 
    
         
             
                    # finished position (in output_ids), used when checking stop conditions with speculative decoding
         
     | 
| 
       495 
517 
     | 
    
         
             
                    self.finished_len = None
         
     | 
| 
       496 
518 
     | 
    
         
             
                    # Whether this request has finished output
         
     | 
| 
       497 
519 
     | 
    
         
             
                    self.finished_output = None
         
     | 
| 
       498 
     | 
    
         
            -
                    # If we want to abort the request in the middle of the event loop, 
     | 
| 
      
 520 
     | 
    
         
            +
                    # If we want to abort the request in the middle of the event loop,
         
     | 
| 
      
 521 
     | 
    
         
            +
                    # set to_finish instead of directly setting finished_reason.
         
     | 
| 
       499 
522 
     | 
    
         
             
                    # Note: We should never set finished_reason in the middle, the req will get filtered and never respond
         
     | 
| 
       500 
     | 
    
         
            -
                    self. 
     | 
| 
       501 
     | 
    
         
            -
                    # This carries the error message for `.to_abort` and will be attached to the finished_reason at the end of the event loop
         
     | 
| 
       502 
     | 
    
         
            -
                    self.to_abort_message: str = None
         
     | 
| 
      
 523 
     | 
    
         
            +
                    self.to_finish: Optional[BaseFinishReason] = None
         
     | 
| 
       503 
524 
     | 
    
         
             
                    self.stream = stream
         
     | 
| 
       504 
525 
     | 
    
         
             
                    self.eos_token_ids = eos_token_ids
         
     | 
| 
       505 
526 
     | 
    
         
             
                    self.vocab_size = vocab_size
         
     | 
| 
         @@ -618,6 +639,9 @@ class Req: 
     | 
|
| 
       618 
639 
     | 
    
         
             
                    # This is used to compute the acceptance rate and average acceptance length per request.
         
     | 
| 
       619 
640 
     | 
    
         
             
                    self.spec_accepted_tokens = 0
         
     | 
| 
       620 
641 
     | 
    
         | 
| 
      
 642 
     | 
    
         
            +
                    # The number of times this request has been retracted / preempted.
         
     | 
| 
      
 643 
     | 
    
         
            +
                    self.retraction_count = 0
         
     | 
| 
      
 644 
     | 
    
         
            +
             
     | 
| 
       621 
645 
     | 
    
         
             
                    # For metrics
         
     | 
| 
       622 
646 
     | 
    
         
             
                    self.metrics_collector = metrics_collector
         
     | 
| 
       623 
647 
     | 
    
         
             
                    self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
         
     | 
| 
         @@ -646,6 +670,9 @@ class Req: 
     | 
|
| 
       646 
670 
     | 
    
         
             
                    self.tmp_end_idx: int = -1
         
     | 
| 
       647 
671 
     | 
    
         
             
                    self.metadata_buffer_index: int = -1
         
     | 
| 
       648 
672 
     | 
    
         | 
| 
      
 673 
     | 
    
         
            +
                    # For Matryoshka embeddings
         
     | 
| 
      
 674 
     | 
    
         
            +
                    self.dimensions = dimensions
         
     | 
| 
      
 675 
     | 
    
         
            +
             
     | 
| 
       649 
676 
     | 
    
         
             
                @property
         
     | 
| 
       650 
677 
     | 
    
         
             
                def seqlen(self):
         
     | 
| 
       651 
678 
     | 
    
         
             
                    return len(self.origin_input_ids) + len(self.output_ids)
         
     | 
| 
         @@ -845,10 +872,9 @@ class Req: 
     | 
|
| 
       845 
872 
     | 
    
         
             
                    if self.finished():
         
     | 
| 
       846 
873 
     | 
    
         
             
                        return
         
     | 
| 
       847 
874 
     | 
    
         | 
| 
       848 
     | 
    
         
            -
                    if self. 
     | 
| 
       849 
     | 
    
         
            -
                        self.finished_reason =  
     | 
| 
       850 
     | 
    
         
            -
             
     | 
| 
       851 
     | 
    
         
            -
                        )
         
     | 
| 
      
 875 
     | 
    
         
            +
                    if self.to_finish:
         
     | 
| 
      
 876 
     | 
    
         
            +
                        self.finished_reason = self.to_finish
         
     | 
| 
      
 877 
     | 
    
         
            +
                        self.to_finish = None
         
     | 
| 
       852 
878 
     | 
    
         
             
                        return
         
     | 
| 
       853 
879 
     | 
    
         | 
| 
       854 
880 
     | 
    
         
             
                    if len(self.output_ids) >= self.sampling_params.max_new_tokens:
         
     | 
| 
         @@ -875,6 +901,10 @@ class Req: 
     | 
|
| 
       875 
901 
     | 
    
         
             
                        return
         
     | 
| 
       876 
902 
     | 
    
         | 
| 
       877 
903 
     | 
    
         
             
                def reset_for_retract(self):
         
     | 
| 
      
 904 
     | 
    
         
            +
                    # Increment retraction count before resetting other state. We should not reset this
         
     | 
| 
      
 905 
     | 
    
         
            +
                    # since we are tracking the total number of retractions for each request.
         
     | 
| 
      
 906 
     | 
    
         
            +
                    self.retraction_count += 1
         
     | 
| 
      
 907 
     | 
    
         
            +
             
     | 
| 
       878 
908 
     | 
    
         
             
                    self.prefix_indices = torch.empty((0,), dtype=torch.int64)
         
     | 
| 
       879 
909 
     | 
    
         
             
                    self.last_node = None
         
     | 
| 
       880 
910 
     | 
    
         
             
                    self.swa_uuid_for_lock = None
         
     | 
| 
         @@ -920,7 +950,7 @@ class Req: 
     | 
|
| 
       920 
950 
     | 
    
         
             
                    self.grammar = None
         
     | 
| 
       921 
951 
     | 
    
         
             
                    self.origin_input_ids = [0]  # set it to one token to skip the long prefill
         
     | 
| 
       922 
952 
     | 
    
         
             
                    self.return_logprob = False
         
     | 
| 
       923 
     | 
    
         
            -
                    self. 
     | 
| 
      
 953 
     | 
    
         
            +
                    self.to_finish = FINISH_ABORT(
         
     | 
| 
       924 
954 
     | 
    
         
             
                        error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         
     | 
| 
       925 
955 
     | 
    
         
             
                    )
         
     | 
| 
       926 
956 
     | 
    
         | 
| 
         @@ -1010,6 +1040,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): 
     | 
|
| 
       1010 
1040 
     | 
    
         
             
                encoder_lens_cpu: Optional[List[int]] = None
         
     | 
| 
       1011 
1041 
     | 
    
         
             
                encoder_out_cache_loc: Optional[torch.Tensor] = None
         
     | 
| 
       1012 
1042 
     | 
    
         | 
| 
      
 1043 
     | 
    
         
            +
                # For matryoshka embeddings
         
     | 
| 
      
 1044 
     | 
    
         
            +
                dimensions: Optional[list[int]] = None
         
     | 
| 
      
 1045 
     | 
    
         
            +
             
     | 
| 
      
 1046 
     | 
    
         
            +
                # For split prefill
         
     | 
| 
      
 1047 
     | 
    
         
            +
                split_index: int = 0
         
     | 
| 
      
 1048 
     | 
    
         
            +
                split_prefill_finished: bool = False
         
     | 
| 
      
 1049 
     | 
    
         
            +
                split_forward_count: int = 1
         
     | 
| 
      
 1050 
     | 
    
         
            +
                split_forward_batch: ForwardBatch = None
         
     | 
| 
      
 1051 
     | 
    
         
            +
                seq_lens_cpu_cache: torch.Tensor = None
         
     | 
| 
      
 1052 
     | 
    
         
            +
             
     | 
| 
       1013 
1053 
     | 
    
         
             
                # Stream
         
     | 
| 
       1014 
1054 
     | 
    
         
             
                has_stream: bool = False
         
     | 
| 
       1015 
1055 
     | 
    
         | 
| 
         @@ -1017,7 +1057,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): 
     | 
|
| 
       1017 
1057 
     | 
    
         
             
                has_grammar: bool = False
         
     | 
| 
       1018 
1058 
     | 
    
         | 
| 
       1019 
1059 
     | 
    
         
             
                # Device
         
     | 
| 
       1020 
     | 
    
         
            -
                 
     | 
| 
      
 1060 
     | 
    
         
            +
                if not _is_npu:
         
     | 
| 
      
 1061 
     | 
    
         
            +
                    device: str = "cuda"
         
     | 
| 
      
 1062 
     | 
    
         
            +
                else:
         
     | 
| 
      
 1063 
     | 
    
         
            +
                    device: str = "npu"
         
     | 
| 
       1021 
1064 
     | 
    
         | 
| 
       1022 
1065 
     | 
    
         
             
                # Speculative decoding
         
     | 
| 
       1023 
1066 
     | 
    
         
             
                spec_algorithm: SpeculativeAlgorithm = None
         
     | 
| 
         @@ -1166,6 +1209,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): 
     | 
|
| 
       1166 
1209 
     | 
    
         
             
                    prefix_lens = [len(r.prefix_indices) for r in reqs]
         
     | 
| 
       1167 
1210 
     | 
    
         
             
                    extend_lens = [r.extend_input_len for r in reqs]
         
     | 
| 
       1168 
1211 
     | 
    
         | 
| 
      
 1212 
     | 
    
         
            +
                    # For matryoshka embeddings
         
     | 
| 
      
 1213 
     | 
    
         
            +
                    if self.model_config.is_matryoshka and any(
         
     | 
| 
      
 1214 
     | 
    
         
            +
                        r.dimensions is not None for r in reqs
         
     | 
| 
      
 1215 
     | 
    
         
            +
                    ):
         
     | 
| 
      
 1216 
     | 
    
         
            +
                        self.dimensions = [
         
     | 
| 
      
 1217 
     | 
    
         
            +
                            r.dimensions if r.dimensions else self.model_config.hidden_size
         
     | 
| 
      
 1218 
     | 
    
         
            +
                            for r in reqs
         
     | 
| 
      
 1219 
     | 
    
         
            +
                        ]
         
     | 
| 
      
 1220 
     | 
    
         
            +
             
     | 
| 
       1169 
1221 
     | 
    
         
             
                    token_type_ids = [
         
     | 
| 
       1170 
1222 
     | 
    
         
             
                        r.token_type_ids for r in reqs if r.token_type_ids is not None
         
     | 
| 
       1171 
1223 
     | 
    
         
             
                    ]
         
     | 
| 
         @@ -1367,6 +1419,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): 
     | 
|
| 
       1367 
1419 
     | 
    
         
             
                    self.extend_num_tokens += running_bs
         
     | 
| 
       1368 
1420 
     | 
    
         
             
                    # TODO (lianmin): Revisit this. It should be seq_len - 1
         
     | 
| 
       1369 
1421 
     | 
    
         
             
                    self.extend_logprob_start_lens.extend([0] * running_bs)
         
     | 
| 
      
 1422 
     | 
    
         
            +
                    self.is_prefill_only = False
         
     | 
| 
       1370 
1423 
     | 
    
         | 
| 
       1371 
1424 
     | 
    
         
             
                def new_page_count_next_decode(self, selected_indices: Optional[List[int]] = None):
         
     | 
| 
       1372 
1425 
     | 
    
         
             
                    page_size = self.token_to_kv_pool_allocator.page_size
         
     | 
| 
         @@ -1397,7 +1450,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): 
     | 
|
| 
       1397 
1450 
     | 
    
         
             
                    evict_from_tree_cache(self.tree_cache, num_tokens)
         
     | 
| 
       1398 
1451 
     | 
    
         
             
                    return self._is_available_size_sufficient(num_tokens)
         
     | 
| 
       1399 
1452 
     | 
    
         | 
| 
       1400 
     | 
    
         
            -
                def retract_decode( 
     | 
| 
      
 1453 
     | 
    
         
            +
                def retract_decode(
         
     | 
| 
      
 1454 
     | 
    
         
            +
                    self, server_args: ServerArgs
         
     | 
| 
      
 1455 
     | 
    
         
            +
                ) -> Tuple[List[Req], float, List[Req]]:
         
     | 
| 
       1401 
1456 
     | 
    
         
             
                    """Retract the decoding requests when there is not enough memory."""
         
     | 
| 
       1402 
1457 
     | 
    
         
             
                    sorted_indices = list(range(len(self.reqs)))
         
     | 
| 
       1403 
1458 
     | 
    
         | 
| 
         @@ -1754,6 +1809,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): 
     | 
|
| 
       1754 
1809 
     | 
    
         
             
                        ),
         
     | 
| 
       1755 
1810 
     | 
    
         
             
                        extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
         
     | 
| 
       1756 
1811 
     | 
    
         
             
                        is_prefill_only=self.is_prefill_only,
         
     | 
| 
      
 1812 
     | 
    
         
            +
                        dimensions=self.dimensions,
         
     | 
| 
       1757 
1813 
     | 
    
         
             
                    )
         
     | 
| 
       1758 
1814 
     | 
    
         | 
| 
       1759 
1815 
     | 
    
         
             
                def copy(self):
         
     | 
| 
         @@ -1862,5 +1918,8 @@ class ModelWorkerBatch: 
     | 
|
| 
       1862 
1918 
     | 
    
         
             
                capture_hidden_mode: CaptureHiddenMode = None
         
     | 
| 
       1863 
1919 
     | 
    
         
             
                hicache_consumer_index: int = -1
         
     | 
| 
       1864 
1920 
     | 
    
         | 
| 
      
 1921 
     | 
    
         
            +
                # For matryoshka embeddings
         
     | 
| 
      
 1922 
     | 
    
         
            +
                dimensions: Optional[list[int]] = None
         
     | 
| 
      
 1923 
     | 
    
         
            +
             
     | 
| 
       1865 
1924 
     | 
    
         
             
                # Whether this batch is prefill-only (no token generation needed)
         
     | 
| 
       1866 
1925 
     | 
    
         
             
                is_prefill_only: bool = False
         
     |