sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py
CHANGED

@@ -43,10 +43,10 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.lora.lora_registry import LoRARegistry
 from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BaseReq,
     BatchEmbeddingOutput,
     BatchMultimodalOutput,
     BatchStrOutput,
@@ -69,6 +69,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.schedule_batch import RequestStage
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
@@ -80,6 +81,7 @@ from sglang.srt.tracing.trace import (
     trace_get_proc_propagate_context,
     trace_req_finish,
     trace_req_start,
+    trace_set_remote_propagate_context,
     trace_slice_end,
     trace_slice_start,
 )
@@ -171,7 +173,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py
-
         speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -180,9 +181,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if speculative_algorithm.is_none()
             else server_args.speculative_num_draft_tokens
         )
-        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
-        self.multi_item_delimiter_text = None

+        # Initialize tokenizer and processor
         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
             try:
@@ -216,6 +216,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             self.mm_processor = get_mm_processor(
                 self.model_config.hf_config, server_args, _processor, transport_mode
             )
+            self.mm_data_processor = AsyncMMDataProcessor(
+                self.mm_processor,
+                max_concurrent_calls=self.server_args.mm_max_concurrent_calls,
+                timeout_s=self.server_args.mm_per_request_timeout,
+            )

         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
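The wrapper constructed here, AsyncMMDataProcessor (new file sglang/srt/managers/async_mm_data_processor.py, +122 lines), puts the multimodal processor behind an async facade with a concurrency cap and a per-request timeout, both fed from the two new server args. Its internals are not part of this diff; a minimal sketch of the same idea, with the delegate method name assumed, would be:

import asyncio
from typing import Any, Optional


class AsyncMMDataProcessorSketch:
    """Illustrative only: bound concurrent multimodal preprocessing and
    enforce a per-request timeout, mirroring the constructor call above."""

    def __init__(self, mm_processor, max_concurrent_calls: int, timeout_s: Optional[float] = None):
        self.mm_processor = mm_processor
        self.semaphore = asyncio.Semaphore(max_concurrent_calls)
        self.timeout_s = timeout_s

    async def process(self, **kwargs: Any):
        async with self.semaphore:  # at most max_concurrent_calls in flight
            coro = self.mm_processor.process_mm_data_async(**kwargs)  # delegate name assumed
            if self.timeout_s is None:
                return await coro
            return await asyncio.wait_for(coro, timeout=self.timeout_s)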
@@ -237,6 +242,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 revision=server_args.revision,
             )
             self._initialize_multi_item_delimiter_text()
+
             # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
             if (
                 server_args.enable_dynamic_batch_tokenizer
@@ -255,24 +261,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        if self.server_args.tokenizer_worker_num > 1:
+        if self.server_args.tokenizer_worker_num == 1:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )
+        else:
+            from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
+
             # Use tokenizer_worker_ipc_name in multi-tokenizer mode
             send_to_scheduler = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
             )

-            class SenderWrapper:
-                def send_pyobj(self, obj):
-                    if isinstance(obj, BaseReq):
-                        obj.http_worker_ipc = port_args.tokenizer_ipc_name
-                    send_to_scheduler.send_pyobj(obj)
-
             # Make sure that each request carries the tokenizer_ipc_name for response routing
-            self.send_to_scheduler = SenderWrapper()
-        else:
-            self.send_to_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-            )
+            self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler)

         # Request states
         self._chosen_loop = None
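The deleted inline class documents exactly what the relocated SenderWrapper has to do: stamp every outgoing BaseReq with this worker's tokenizer IPC name so the scheduler can route the response back to the right HTTP worker, then forward the object unchanged. The real class now lives in sglang/srt/managers/multi_tokenizer_mixin.py (not shown in this diff); a reconstruction from the removed code would look like:

from sglang.srt.managers.io_struct import BaseReq  # import moved out of tokenizer_manager, see the first hunk


class SenderWrapper:
    """Reconstructed from the inline class removed above; the shipped
    version in multi_tokenizer_mixin.py may differ in detail."""

    def __init__(self, port_args, send_to_scheduler):
        self.port_args = port_args
        self.send_to_scheduler = send_to_scheduler

    def send_pyobj(self, obj):
        # Tag each request with this worker's IPC name for response routing.
        if isinstance(obj, BaseReq):
            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
        self.send_to_scheduler.send_pyobj(obj)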
@@ -320,6 +322,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

+        # Disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -388,10 +391,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        if
-
+        if request:
+            if "trace_context" in request.headers:
+                trace_set_remote_propagate_context(request.headers["trace_context"])

-
+        if self.server_args.tokenizer_worker_num > 1:
             self._attach_multi_http_worker_info(obj)

         if self.enable_trace:
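trace_set_remote_propagate_context lets an upstream service hand its trace context to this server through a trace_context HTTP header; only the header name is established by this hunk, and the payload is whatever the peer's tracing layer serialized. A hedged client-side example:

import requests

# The header name comes from the hunk above; the value is assumed to be the
# serialized trace context produced by the caller's tracing layer
# (sglang/srt/tracing/trace.py, +121 -12 in this release).
trace_context = "<serialized-trace-context>"
resp = requests.post(
    "http://localhost:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 16}},
    headers={"trace_context": trace_context},
)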
@@ -600,10 +604,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 obj.image_data = [obj.image_data]
             if obj.audio_data is not None and not isinstance(obj.audio_data, list):
                 obj.audio_data = [obj.audio_data]
-            mm_inputs: Dict = await self.
+            mm_inputs: Dict = await self.mm_data_processor.process(
                 image_data=obj.image_data,
                 audio_data=obj.audio_data,
-
+                input_text_or_ids=(input_text or input_ids),
                 request_obj=obj,
                 max_req_input_len=self.max_req_input_len,
             )
@@ -613,7 +617,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
-        trace_slice_end(
+        trace_slice_end(RequestStage.TOKENIZE, obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -674,6 +678,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             )
             raise ValueError(error_msg)

+        # Matryoshka embeddings validations
+        if isinstance(obj, EmbeddingReqInput):
+            self._validate_for_matryoshka_dim(obj)
+
         if isinstance(obj, GenerateReqInput):
             if (
                 obj.return_hidden_states
@@ -692,6 +700,34 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "Please set `--enable-custom-logit-processor` to enable this feature."
             )

+    def _validate_for_matryoshka_dim(self, obj: EmbeddingReqInput) -> None:
+        """Validate the request for Matryoshka dim if it has the field set."""
+        if obj.dimensions is None:
+            return
+
+        if not self.model_config.is_matryoshka:
+            raise ValueError(
+                f"Model '{self.model_config.model_path}' does not support matryoshka representation, "
+                f"changing output dimensions will lead to poor results."
+            )
+
+        if obj.dimensions < 1:
+            raise ValueError("Requested dimensions must be greater than 0")
+
+        if (
+            self.model_config.matryoshka_dimensions
+            and obj.dimensions not in self.model_config.matryoshka_dimensions
+        ):
+            raise ValueError(
+                f"Model '{self.model_config.model_path}' only supports {self.model_config.matryoshka_dimensions} matryoshka dimensions, "
+                f"using other output dimensions will lead to poor results."
+            )
+
+        if obj.dimensions > self.model_config.hidden_size:
+            raise ValueError(
+                f"Provided dimensions are greater than max embedding dimension: {self.model_config.hidden_size}"
+            )
+
     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
     ) -> None:
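Matryoshka-trained embedding models make every leading slice of the output vector a valid lower-dimensional embedding, which is why the validator only has to check that the requested size is positive, supported by the model, and no larger than the hidden size. Downstream, serving such a request amounts to truncating and re-normalizing; a sketch of that step (the real change plausibly lands in sglang/srt/layers/pooler.py, which is touched in this release):

import torch
import torch.nn.functional as F

def apply_matryoshka_dim(embedding: torch.Tensor, dimensions: int) -> torch.Tensor:
    # Keep the leading `dimensions` components, then L2-renormalize so the
    # truncated vector is still unit length. Sketch only.
    return F.normalize(embedding[..., :dimensions], p=2, dim=-1)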
@@ -760,6 +796,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             sampling_params,
             rid=obj.rid,
             priority=obj.priority,
+            dimensions=obj.dimensions,
             http_worker_ipc=obj.http_worker_ipc,
         )

@@ -806,7 +843,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
-            trace_slice_end(
+            trace_slice_end(RequestStage.TOKENIZE, req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs

@@ -858,12 +895,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        trace_slice_start(
+        trace_slice_start(RequestStage.TOKENIZER_DISPATCH, obj.rid)
         tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
-        trace_slice_end(
+        trace_slice_end(
+            RequestStage.TOKENIZER_DISPATCH, obj.rid, thread_finish_flag=True
+        )
         return state

     def _send_batch_request(
@@ -1365,6 +1404,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
+                "total_retractions": recv_obj.retraction_counts[i],
             }

             if getattr(state.obj, "return_logprob", False):
@@ -1453,6 +1493,51 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         if self.crash_dump_folder and state.finished and state.obj.log_metrics:
             self.record_request_for_crash_dump(state, out_dict)

+    def add_logprob_to_meta_info(
+        self,
+        meta_info: dict,
+        state: ReqState,
+        top_logprobs_num: int,
+        token_ids_logprob: List[int],
+        return_text_in_logprobs: bool,
+    ):
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            state.input_token_logprobs_val,
+            state.input_token_logprobs_idx,
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            state.output_token_logprobs_val,
+            state.output_token_logprobs_idx,
+            return_text_in_logprobs,
+        )
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                state.input_top_logprobs_val,
+                state.input_top_logprobs_idx,
+                return_text_in_logprobs,
+            )
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                state.output_top_logprobs_val,
+                state.output_top_logprobs_idx,
+                return_text_in_logprobs,
+            )
+
+        if token_ids_logprob is not None:
+            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
+                state.input_token_ids_logprobs_val,
+                state.input_token_ids_logprobs_idx,
+                return_text_in_logprobs,
+            )
+            meta_info["output_token_ids_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    state.output_token_ids_logprobs_val,
+                    state.output_token_ids_logprobs_idx,
+                    return_text_in_logprobs,
+                )
+            )
+
     def convert_logprob_style(
         self,
         meta_info: dict,
@@ -1479,16 +1564,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             state.output_token_logprobs_idx.extend(
                 recv_obj.output_token_logprobs_idx[recv_obj_index]
             )
-        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
-            state.input_token_logprobs_val,
-            state.input_token_logprobs_idx,
-            return_text_in_logprobs,
-        )
-        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
-            state.output_token_logprobs_val,
-            state.output_token_logprobs_idx,
-            return_text_in_logprobs,
-        )

         if top_logprobs_num > 0:
             if len(recv_obj.input_top_logprobs_val) > 0:
@@ -1504,16 +1579,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 state.output_top_logprobs_idx.extend(
                     recv_obj.output_top_logprobs_idx[recv_obj_index]
                 )
-            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                state.input_top_logprobs_val,
-                state.input_top_logprobs_idx,
-                return_text_in_logprobs,
-            )
-            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
-                state.output_top_logprobs_val,
-                state.output_top_logprobs_idx,
-                return_text_in_logprobs,
-            )

         if token_ids_logprob is not None:
             if len(recv_obj.input_token_ids_logprobs_val) > 0:
@@ -1529,18 +1594,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 state.output_token_ids_logprobs_idx.extend(
                     recv_obj.output_token_ids_logprobs_idx[recv_obj_index]
                 )
-            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
-                state.input_token_ids_logprobs_val,
-                state.input_token_ids_logprobs_idx,
-                return_text_in_logprobs,
-            )
-            meta_info["output_token_ids_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    state.output_token_ids_logprobs_val,
-                    state.output_token_ids_logprobs_idx,
-                    return_text_in_logprobs,
-                )
-            )
+
+        self.add_logprob_to_meta_info(
+            meta_info,
+            state,
+            state.obj.top_logprobs_num,
+            state.obj.token_ids_logprob,
+            return_text_in_logprobs,
+        )

     def detokenize_logprob_tokens(
         self,
@@ -1657,6 +1718,14 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             or state.obj.sampling_params.get("ebnf", None)
             or state.obj.sampling_params.get("structural_tag", None)
         )
+
+        retraction_count = (
+            recv_obj.retraction_counts[i]
+            if getattr(recv_obj, "retraction_counts", None)
+            and i < len(recv_obj.retraction_counts)
+            else 0
+        )
+
         self.metrics_collector.observe_one_finished_request(
             labels,
             recv_obj.prompt_tokens[i],
@@ -1664,6 +1733,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             recv_obj.cached_tokens[i],
             state.finished_time - state.created_time,
             has_grammar,
+            retraction_count,
         )

     def dump_requests(self, state: ReqState, out_dict: dict):
@@ -1716,26 +1786,33 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             return
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
+
+        abort_message = recv_obj.abort_message or "Abort in waiting queue"
+        finish_reason = {
+            "type": "abort",
+            "message": abort_message,
+        }
         if recv_obj.finished_reason:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            finish_reason = recv_obj.finished_reason
+        meta_info = {"id": recv_obj.rid, "finish_reason": finish_reason}
+        is_stream = getattr(state.obj, "stream", False)
+        if getattr(state.obj, "return_logprob", False):
+            self.add_logprob_to_meta_info(
+                meta_info,
+                state,
+                state.obj.top_logprobs_num,
+                state.obj.token_ids_logprob,
+                state.obj.return_text_in_logprobs
+                and not self.server_args.skip_tokenizer_init,
+            )
+
+        output_ids = state.output_ids
+        meta_info["completion_tokens"] = len(output_ids)
+        out = {
+            "text": state.text,
+            "output_ids": [output_ids[-1]] if is_stream else output_ids,
+            "meta_info": meta_info,
+        }
         state.out_list.append(out)
         state.event.set()
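After this rewrite an aborted request produces the same output shape as a finished one, including logprobs when they were requested, so client code needs no abort-specific branch. For a streaming request aborted in the waiting queue, the appended entry would look roughly like this (values illustrative):

out = {
    "text": "Hi the",               # whatever text existed at abort time
    "output_ids": [1169],           # last token only, because is_stream is True
    "meta_info": {
        "id": "req-123",
        "finish_reason": {"type": "abort", "message": "Abort in waiting queue"},
        "completion_tokens": 3,     # len(state.output_ids)
    },
}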
@@ -2096,7 +2173,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             bootstrap_room = (
                 obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
             )
-            trace_req_start(
+            trace_req_start(
+                obj.rid,
+                bootstrap_room,
+                ts=int(created_time * 1e9),
+                role=self.server_args.disaggregation_mode,
+            )
             trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
         else:
             for i in range(len(obj.rid)):
@@ -2105,7 +2187,12 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                     if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
                     else None
                 )
-                trace_req_start(
+                trace_req_start(
+                    obj.rid[i],
+                    bootstrap_room,
+                    ts=int(created_time * 1e9),
+                    role=self.server_args.disaggregation_mode,
+                )
                 trace_slice_start(
                     "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
                 )
sglang/srt/managers/tp_worker.py
CHANGED

@@ -35,7 +35,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromIPCReqInput,
     UpdateWeightsFromTensorReqInput,
 )
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
 from sglang.srt.managers.scheduler import GenerationBatchResult
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
@@ -425,3 +425,26 @@ class TpModelWorker(BaseTpWorker):
             pp_hidden_states_proxy_tensors=pp_proxy_tensors,
             can_run_cuda_graph=can_run_cuda_graph,
         )
+
+    def forward_batch_split_prefill(self, batch: ScheduleBatch):
+        if batch.split_index == 0:
+            model_worker_batch = batch.get_model_worker_batch()
+            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+            batch.split_forward_batch = forward_batch
+            batch.seq_lens_cpu_cache = model_worker_batch.seq_lens_cpu
+        else:
+            model_worker_batch = batch.get_model_worker_batch(batch.seq_lens_cpu_cache)
+
+        logits_output, can_run_cuda_graph = self.model_runner.forward(
+            batch.split_forward_batch, split_forward_count=batch.split_forward_count
+        )
+        if logits_output:
+            next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        else:
+            next_token_ids = None
+        batch_result = GenerationBatchResult(
+            logits_output=logits_output,
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+        batch_result.next_token_ids = next_token_ids
+        return batch_result
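forward_batch_split_prefill executes one slice of a chunk-split prefill per call: the first call (split_index == 0) builds the ForwardBatch and caches it, together with the CPU sequence lengths, on the ScheduleBatch; later calls reuse both, and sampling happens only once a slice actually returns logits. The scheduler-side driver is not in this hunk; a hypothetical loop over the slices, with the split_index bookkeeping assumed, would be:

# Hypothetical driver; only forward_batch_split_prefill itself is in this diff.
batch.split_index = 0
batch.split_forward_count = num_splits  # assumed to be set by the scheduler
result = None
for _ in range(num_splits):
    result = worker.forward_batch_split_prefill(batch)
    batch.split_index += 1
# Intermediate slices return logits_output=None; the final one carries
# the sampled next_token_ids.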
sglang/srt/mem_cache/base_prefix_cache.py
CHANGED

@@ -1,12 +1,31 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    NamedTuple,
+    Optional,
+    Protocol,
+    Tuple,
+    runtime_checkable,
+)

 import torch

+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
-
-
+
+
+@runtime_checkable
+class PrefixCacheTrait(Protocol):
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator
+    page_size: int
+    disable: bool


 class MatchResult(NamedTuple):
@@ -28,7 +47,7 @@ class MatchResult(NamedTuple):
     host_hit_length: int = 0


-class BasePrefixCache(ABC):
+class BasePrefixCache(ABC, PrefixCacheTrait):
     """Cache can be indexed by either rid or key."""

     @abstractmethod
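PrefixCacheTrait is structural typing: @runtime_checkable makes isinstance verify only that the listed attributes exist, and adding it as a base of BasePrefixCache mostly declares, for type checkers, the attributes every concrete cache is expected to expose. A self-contained illustration of the mechanism:

from typing import Protocol, runtime_checkable

@runtime_checkable
class HasPageSize(Protocol):
    page_size: int

class TinyCache:  # note: no inheritance from HasPageSize
    def __init__(self) -> None:
        self.page_size = 16

# Passes: runtime_checkable protocols check attribute presence, not ancestry.
assert isinstance(TinyCache(), HasPageSize)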
sglang/srt/mem_cache/common.py
CHANGED

@@ -89,6 +89,7 @@ def write_cache_indices(
     prefix_pointers = torch.tensor(
         [t.data_ptr() for t in prefix_tensors],
         device=req_to_token_pool.device,
+        dtype=torch.uint64,
     )
     # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
     write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
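data_ptr() returns the raw device address as a Python int, and torch.tensor would otherwise infer int64 for it; pinning dtype=torch.uint64 keeps the pointer table unsigned end to end, matching what the Triton kernel presumably declares for its 64-bit pointer arguments. A tiny illustration of the construction (requires a torch version with uint64 support):

import torch

t = torch.empty(4)
ptrs = torch.tensor([t.data_ptr()], dtype=torch.uint64)  # explicit unsigned 64-bit
assert ptrs.dtype == torch.uint64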
sglang/srt/mem_cache/hicache_storage.py
CHANGED

@@ -19,7 +19,13 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
         hasher.update(bytes.fromhex(prior_hash))

     for t in token_ids:
-        hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+        if isinstance(t, tuple):
+            # EAGLE bigram mode: hash both elements to uniquely identify the bigram
+            for elem in t:
+                hasher.update(elem.to_bytes(4, byteorder="little", signed=False))
+        else:
+            # Regular mode: single integer token
+            hasher.update(t.to_bytes(4, byteorder="little", signed=False))

     return hasher.hexdigest()
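get_hash_str folds the previous page's hex digest into the next page's hash, so a page hash identifies the entire token prefix; the new branch extends the scheme to EAGLE's bigram keys, where each cache key is a tuple of two token ids. A runnable sketch of the patched function and its chaining:

import hashlib
from typing import List, Optional

def get_hash_str(token_ids: List, prior_hash: Optional[str] = None) -> str:
    # Mirrors the patched function: tuples (EAGLE bigrams) hash both
    # elements; plain ints hash as a single 4-byte little-endian value.
    hasher = hashlib.sha256()
    if prior_hash:
        hasher.update(bytes.fromhex(prior_hash))
    for t in token_ids:
        for elem in t if isinstance(t, tuple) else (t,):
            hasher.update(elem.to_bytes(4, byteorder="little", signed=False))
    return hasher.hexdigest()

page1 = get_hash_str([1, 2, 3, 4])
page2 = get_hash_str([(5, 6), (6, 7)], prior_hash=page1)  # chained bigram page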