sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -79,13 +79,17 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     MultiTokenizerRegisterReq,
-
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    SendWeightsToRemoteInstanceReqInput,
+    SendWeightsToRemoteInstanceReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     SlowDownReqInput,
@@ -145,6 +151,15 @@ from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_event,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -158,6 +173,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -348,6 +364,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None

@@ -401,7 +429,7 @@ class Scheduler(
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )

         # Init memory pool and cache
@@ -488,7 +516,7 @@ class Scheduler(
             enable=server_args.enable_memory_saver
         )
         self.offload_tags = set()
-        self.
+        self.init_profiler()

         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
@@ -525,6 +553,14 @@ class Scheduler(
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    InitWeightsSendGroupForRemoteInstanceReqInput,
+                    self.init_weights_send_group_for_remote_instance,
+                ),
+                (
+                    SendWeightsToRemoteInstanceReqInput,
+                    self.send_weights_to_remote_instance,
+                ),
                 (
                     UpdateWeightsFromDistributedReqInput,
                     self.update_weights_from_distributed,
@@ -543,6 +579,7 @@ class Scheduler(
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )

@@ -622,6 +659,7 @@ class Scheduler(
                 hicache_write_policy=server_args.hicache_write_policy,
                 hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
+                enable_metrics=self.enable_metrics,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                 model_name=server_args.served_model_name,
@@ -654,6 +692,21 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
+        elif server_args.enable_lmcache:
+            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+                LMCRadixCache,
+            )
+
+            self.tree_cache = LMCRadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+                model_config=self.model_config,
+                tp_size=self.tp_size,
+                rank=self.tp_rank,
+                tp_group=self.tp_group,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -785,6 +838,10 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
@@ -806,6 +863,10 @@ class Scheduler(
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch

+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
@@ -1069,6 +1130,12 @@ class Scheduler(
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
             )
+
+        for req in recv_reqs:
+            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
+                trace_set_proc_propagate_context(req.rid, req.trace_context)
+                trace_slice_start("", req.rid, anonymous=True)
+
         return recv_reqs

     def process_input_requests(self, recv_reqs: List):
@@ -1096,13 +1163,13 @@ class Scheduler(
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue

-            # If it is a
-            if isinstance(recv_req,
+            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWrapper):
                 worker_id = recv_req.worker_id
                 recv_req = recv_req.obj
                 output = self._request_dispatcher(recv_req)
                 if output is not None:
-                    output =
+                    output = MultiTokenizerWrapper(worker_id, output)
                     self.send_to_tokenizer.send_pyobj(output)
                 continue

@@ -1114,6 +1181,16 @@ class Scheduler(
             else:
                 self.send_to_tokenizer.send_pyobj(output)

+    def init_req_max_new_tokens(self, req):
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_len - len(req.origin_input_ids) - 1,
+        )
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
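The new init_req_max_new_tokens helper above centralizes the max_new_tokens clamp (the inline copy is removed in a later hunk of this file): the requested value, or a 1 << 30 sentinel when unset, is capped by the remaining room in the request, max_req_len - len(origin_input_ids) - 1. A worked example with hypothetical numbers:

    # Illustrative only; mirrors the min(...) in init_req_max_new_tokens above.
    requested = None                          # client did not set max_new_tokens
    effective = requested if requested is not None else 1 << 30
    budget = 4096 - 1000 - 1                  # max_req_len - len(origin_input_ids) - 1
    print(min(effective, budget))             # 3095
    print(min(128, budget))                   # 128 if the client had asked for 128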
@@ -1177,6 +1254,7 @@ class Scheduler(
                 req.set_finish_with_abort(
                     f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
+                self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
                 return
         else:
@@ -1184,6 +1262,7 @@ class Scheduler(
             session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
+                self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
                 return

@@ -1203,9 +1282,13 @@ class Scheduler(
                     f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                 )
             )
+            self.init_req_max_new_tokens(req)
             self._add_request_to_queue(req)
             return

+        # initialize before returning
+        self.init_req_max_new_tokens(req)
+
         # Validate prompt length
         error_msg = validate_input_length(
             req,
@@ -1220,26 +1303,25 @@ class Scheduler(
         # Copy more attributes
         if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
             # By default, only return the logprobs for output tokens
-
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
         else:
             req.logprob_start_len = recv_req.logprob_start_len

-        if req.logprob_start_len >= len(
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
             error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
             req.set_finish_with_abort(error_msg)
             self._add_request_to_queue(req)
             return

-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_len - len(req.origin_input_ids) - 1,
-        )
-
         # Init grammar cache for this request
         add_to_grammar_queue = False
         if (
@@ -1298,6 +1380,7 @@ class Scheduler(
         else:
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
+            trace_slice_end("process req", req.rid, auto_next_anon=True)

     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
@@ -1409,9 +1492,11 @@ class Scheduler(
         _, _, available_size, evictable_size = self._get_token_info()
         protected_size = self.tree_cache.protected_size()
         memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
             self.max_total_num_tokens
-
-            else self.max_total_num_tokens - protected_size
+            - protected_size
         )
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"

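The leak check in the hunk above now always compares the free-plus-evictable token count against max_total_num_tokens minus protected_size; the old hierarchical-cache branch is kept only as a comment. A toy numeric check of the identity being asserted, with made-up pool sizes:

    # Hypothetical numbers, only to illustrate the accounting identity above.
    max_total_num_tokens = 10_000
    protected_size = 200          # tokens pinned by the cache
    available_size = 7_300
    evictable_size = 2_500
    memory_leak = (available_size + evictable_size) != (max_total_num_tokens - protected_size)
    print(memory_leak)            # False: 7300 + 2500 == 9800, the pool balances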
@@ -1462,6 +1547,20 @@ class Scheduler(
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()

@@ -1509,7 +1608,12 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1776,10 +1880,6 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()

-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1848,8 +1948,23 @@ class Scheduler(
     ):
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "decode loop",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
+
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result, launch_done)
+            for req in batch.reqs:
+                trace_slice(
+                    "prefill",
+                    req.rid,
+                    auto_next_anon=not req.finished(),
+                    thread_finish_flag=req.finished(),
+                )
         elif batch.forward_mode.is_idle():
             if self.enable_overlap:
                 self.tp_worker.resolve_last_batch_result(launch_done)
@@ -2174,39 +2289,50 @@ class Scheduler(
             if_success = False
         return if_success

-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
             )
-
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
             )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)

-        return
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )

     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
@@ -2221,10 +2347,9 @@ class Scheduler(
             "token_capacity": int(self.max_total_num_tokens),
         }

-
-
-
-        )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2233,8 +2358,6 @@ class Scheduler(
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict

-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)

     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2398,6 +2521,22 @@ class Scheduler(
         self.send_to_detokenizer.send_pyobj(recv_req)
         return recv_req

+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2519,6 +2658,15 @@ def run_scheduler_process(
     pipe_writer,
     balance_meta: Optional[DPBalanceMeta] = None,
 ):
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
     # Generate the prefix
     prefix = ""
     if dp_rank is not None:
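The get_load hunk above changes the method from returning a bare token count to returning a structured GetLoadReqOutput, and wires it into the request dispatcher via GetLoadReqInput. The real message definitions live in sglang/srt/managers/io_struct.py (also changed in this release but not shown here); the sketch below is only an assumption of the payload shape implied by the constructor call above, with assumed field types:

    # Minimal sketch of the payload implied by the GetLoadReqOutput(...) call above.
    # Field types are assumptions; the authoritative definition is in io_struct.py.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class GetLoadReqOutputSketch:
        dp_rank: Optional[int]     # reporting data-parallel rank
        num_reqs: int              # running + waiting requests
        num_waiting_reqs: int      # waiting queue plus disaggregation queues
        num_tokens: int            # KV tokens in use plus queued prompt tokens

A caller can rank schedulers by the num_tokens field, which is what SchedulerMetricsMixin now does via self.get_load().num_tokens (see the next file).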
sglang/srt/managers/scheduler_metrics_mixin.py
CHANGED
@@ -214,7 +214,7 @@ class SchedulerMetricsMixin:
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
-            f"cuda graph: {can_run_cuda_graph}, "
+            f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )
@@ -230,7 +230,7 @@ class SchedulerMetricsMixin:
         self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
         self.stats.spec_accept_length = spec_accept_length
         self.stats.total_retracted_reqs = self.total_retracted_reqs
-        self.
+        self.stats.avg_request_queue_latency = 0.0
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             self.stats.num_decode_prealloc_queue_reqs = len(
                 self.disagg_decode_prealloc_queue.queue
@@ -238,6 +238,7 @@ class SchedulerMetricsMixin:
             self.stats.num_decode_transfer_queue_reqs = len(
                 self.disagg_decode_transfer_queue.queue
             )
+        self.metrics_collector.log_stats(self.stats)
         self._emit_kv_metrics()
         self._publish_kv_events()

@@ -278,7 +279,7 @@ class SchedulerMetricsMixin:
             self.server_args.load_balance_method == "minimum_tokens"
             and self.forward_ct % 40 == 0
         ):
-            holding_tokens = self.get_load()
+            holding_tokens = self.get_load().num_tokens

         new_recv_dp_balance_id_list, holding_token_list = (
             self.gather_dp_balance_info(holding_tokens)
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -5,6 +5,8 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
@@ -71,6 +73,7 @@ class SchedulerOutputProcessorMixin:

         # Check finish conditions
         logprob_pt = 0
+
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
@@ -99,6 +102,7 @@ class SchedulerOutputProcessorMixin:
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     num_input_logprobs = extend_input_len - extend_logprob_start_len
+
                     if req.return_logprob:
                         self.add_logprob_return_values(
                             i,
@@ -441,27 +445,59 @@ class SchedulerOutputProcessorMixin:
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-
-
-
-
-
-
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)

         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

-        if
-            req.
-
-
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
             )

         return num_input_logprobs

+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
@@ -700,6 +736,8 @@ class SchedulerOutputProcessorMixin:
                     output_token_ids_logprobs_val,
                     output_token_ids_logprobs_idx,
                     output_hidden_states,
+                    placeholder_tokens_idx=None,
+                    placeholder_tokens_val=None,
                 )
             )

@@ -719,6 +757,12 @@ class SchedulerOutputProcessorMixin:
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
             BatchEmbeddingOut(
-                rids,
+                rids,
+                finished_reasons,
+                embeddings,
+                prompt_tokens,
+                cached_tokens,
+                placeholder_tokens_idx=None,
+                placeholder_tokens_val=None,
             )
         )
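The token-id logprob hunk above now guards against per-token logprob values arriving as GPU tensors, converting them to plain Python lists before they are appended for serialization. A small self-contained sketch of that normalization pattern (hypothetical helper name and values; requires torch):

    # Illustrative only: normalize a value that may be a torch.Tensor into a plain list,
    # the same guard used before appending token-id logprobs above.
    import torch

    def to_plain_list(values):
        return values.tolist() if isinstance(values, torch.Tensor) else values

    print(to_plain_list(torch.tensor([-0.25, -1.5])))  # [-0.25, -1.5]
    print(to_plain_list([-0.25, -1.5]))                # already a list, returned unchanged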
sglang/srt/managers/scheduler_profiler_mixin.py
CHANGED
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)

 class SchedulerProfilerMixin:

-    def
+    def init_profiler(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
         self.profiler_activities: Optional[List[str]] = None