sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,24 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import logging
|
2
4
|
import time
|
3
5
|
from collections import defaultdict
|
4
|
-
from typing import List, Optional
|
6
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
7
|
+
|
8
|
+
import torch
|
5
9
|
|
6
10
|
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
7
11
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
12
|
+
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
8
13
|
from sglang.srt.managers.schedule_policy import PrefillAdder
|
9
14
|
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
15
|
+
from sglang.srt.managers.utils import DPBalanceMeta
|
10
16
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
11
17
|
from sglang.srt.utils import get_bool_env_var
|
12
18
|
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
from sglang.srt.managers.scheduler import Scheduler
|
21
|
+
|
13
22
|
logger = logging.getLogger(__name__)
|
14
23
|
|
15
24
|
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
@@ -28,7 +37,9 @@ class KvMetrics:
|
|
28
37
|
|
29
38
|
|
30
39
|
class SchedulerMetricsMixin:
|
31
|
-
def init_metrics(
|
40
|
+
def init_metrics(
|
41
|
+
self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
|
42
|
+
):
|
32
43
|
self.last_gen_throughput: float = 0.0
|
33
44
|
self.last_input_throughput: float = 0.0
|
34
45
|
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
|
|
50
61
|
labels["dp_rank"] = dp_rank
|
51
62
|
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
|
52
63
|
|
53
|
-
def
|
64
|
+
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
|
65
|
+
self.balance_meta = dp_balance_meta
|
66
|
+
if (
|
67
|
+
self.server_args.enable_dp_attention
|
68
|
+
and self.server_args.load_balance_method == "minimum_tokens"
|
69
|
+
):
|
70
|
+
assert dp_balance_meta is not None
|
71
|
+
|
72
|
+
self.recv_dp_balance_id_this_term = []
|
73
|
+
|
74
|
+
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
|
54
75
|
if self.enable_kv_cache_events:
|
55
76
|
self.kv_event_publisher = EventPublisherFactory.create(
|
56
77
|
kv_events_config, self.attn_dp_rank
|
57
78
|
)
|
58
79
|
|
59
80
|
def log_prefill_stats(
|
60
|
-
self,
|
81
|
+
self: Scheduler,
|
61
82
|
adder: PrefillAdder,
|
62
83
|
can_run_list: List[Req],
|
63
84
|
running_bs: int,
|
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
|
|
138
159
|
self._publish_kv_events()
|
139
160
|
|
140
161
|
def log_decode_stats(
|
141
|
-
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
162
|
+
self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
142
163
|
):
|
143
164
|
batch = running_batch or self.running_batch
|
144
165
|
|
@@ -193,7 +214,7 @@ class SchedulerMetricsMixin:
|
|
193
214
|
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
|
194
215
|
|
195
216
|
msg += (
|
196
|
-
f"cuda graph: {can_run_cuda_graph}, "
|
217
|
+
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
|
197
218
|
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
|
198
219
|
f"#queue-req: {len(self.waiting_queue)}, "
|
199
220
|
)
|
@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
|
|
220
241
|
self._emit_kv_metrics()
|
221
242
|
self._publish_kv_events()
|
222
243
|
|
223
|
-
def _emit_kv_metrics(self):
|
244
|
+
def _emit_kv_metrics(self: Scheduler):
|
224
245
|
kv_metrics = KvMetrics()
|
225
246
|
kv_metrics.request_active_slots = self.stats.num_running_reqs
|
226
247
|
kv_metrics.request_total_slots = self.max_running_requests
|
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
|
|
236
257
|
if not self.send_metrics_from_scheduler.closed:
|
237
258
|
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
|
238
259
|
|
239
|
-
def _publish_kv_events(self):
|
260
|
+
def _publish_kv_events(self: Scheduler):
|
240
261
|
if self.enable_kv_cache_events:
|
241
262
|
events = self.tree_cache.take_events()
|
242
263
|
if events:
|
243
264
|
batch = KVEventBatch(ts=time.time(), events=events)
|
244
265
|
self.kv_event_publisher.publish(batch)
|
266
|
+
|
267
|
+
def maybe_update_dp_balance_data(
|
268
|
+
self: Scheduler, recv_req: TokenizedGenerateReqInput
|
269
|
+
):
|
270
|
+
if (
|
271
|
+
self.server_args.enable_dp_attention
|
272
|
+
and self.server_args.load_balance_method == "minimum_tokens"
|
273
|
+
):
|
274
|
+
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
|
275
|
+
|
276
|
+
def maybe_handle_dp_balance_data(self: Scheduler):
|
277
|
+
if (
|
278
|
+
self.server_args.load_balance_method == "minimum_tokens"
|
279
|
+
and self.forward_ct % 40 == 0
|
280
|
+
):
|
281
|
+
holding_tokens = self.get_load()
|
282
|
+
|
283
|
+
new_recv_dp_balance_id_list, holding_token_list = (
|
284
|
+
self.gather_dp_balance_info(holding_tokens)
|
285
|
+
)
|
286
|
+
|
287
|
+
self.recv_dp_balance_id_this_term.clear()
|
288
|
+
if self.tp_rank == 0: # only first worker write info
|
289
|
+
self.write_shared_dp_balance_info(
|
290
|
+
new_recv_dp_balance_id_list, holding_token_list
|
291
|
+
)
|
292
|
+
|
293
|
+
def gather_dp_balance_info(
|
294
|
+
self: Scheduler, holding_tokens_list
|
295
|
+
) -> Union[None, List[List[int]]]:
|
296
|
+
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
|
297
|
+
recv_list = self.recv_dp_balance_id_this_term
|
298
|
+
assert len(recv_list) <= 511, (
|
299
|
+
"The number of requests received this round is too large. "
|
300
|
+
"Please increase gather_tensor_size and onfly_info_size."
|
301
|
+
)
|
302
|
+
# The maximum size of the tensor used for gathering data from all workers.
|
303
|
+
gather_tensor_size = 512
|
304
|
+
|
305
|
+
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
|
306
|
+
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
|
307
|
+
recv_tensor[0] = holding_tokens_list
|
308
|
+
recv_tensor[1] = len(recv_list) # The first element is the length of the list.
|
309
|
+
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
|
310
|
+
|
311
|
+
if self.tp_rank == 0:
|
312
|
+
gathered_list = [
|
313
|
+
torch.zeros(gather_tensor_size, dtype=torch.int32)
|
314
|
+
for _ in range(self.balance_meta.num_workers)
|
315
|
+
]
|
316
|
+
else:
|
317
|
+
gathered_list = None
|
318
|
+
|
319
|
+
torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
|
320
|
+
|
321
|
+
gathered_id_list_per_worker = None
|
322
|
+
if self.tp_rank == 0:
|
323
|
+
gathered_id_list_per_worker = []
|
324
|
+
holding_tokens_list = []
|
325
|
+
for tensor in gathered_list:
|
326
|
+
holding_tokens_list.append(tensor[0].item())
|
327
|
+
list_length = tensor[1].item()
|
328
|
+
gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
|
329
|
+
|
330
|
+
return gathered_id_list_per_worker, holding_tokens_list
|
331
|
+
|
332
|
+
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
|
333
|
+
meta = self.balance_meta
|
334
|
+
|
335
|
+
with meta.mutex:
|
336
|
+
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
|
337
|
+
assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
|
338
|
+
# 1.Check if the rid received by each worker this round is present in onfly.
|
339
|
+
# If it is, remove the corresponding onfly item.
|
340
|
+
worker_id = 0
|
341
|
+
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
|
342
|
+
for new_recv_rid in new_recv_rids:
|
343
|
+
assert (
|
344
|
+
new_recv_rid in on_fly_reqs
|
345
|
+
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
|
346
|
+
del on_fly_reqs[new_recv_rid]
|
347
|
+
worker_id += 1
|
348
|
+
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
349
|
+
meta.set_shared_onfly_info(onfly_list)
|
350
|
+
meta.set_shared_local_tokens(local_tokens)
|
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
|
|
93
93
|
# This updates radix so others can match
|
94
94
|
self.tree_cache.cache_unfinished_req(req)
|
95
95
|
|
96
|
-
if
|
96
|
+
if batch.return_logprob:
|
97
97
|
assert extend_logprob_start_len_per_req is not None
|
98
98
|
assert extend_input_len_per_req is not None
|
99
99
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
100
100
|
extend_input_len = extend_input_len_per_req[i]
|
101
101
|
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
102
|
+
if req.return_logprob:
|
103
|
+
self.add_logprob_return_values(
|
104
|
+
i,
|
105
|
+
req,
|
106
|
+
logprob_pt,
|
107
|
+
next_token_ids,
|
108
|
+
num_input_logprobs,
|
109
|
+
logits_output,
|
110
|
+
)
|
110
111
|
logprob_pt += num_input_logprobs
|
111
112
|
|
112
113
|
if (
|
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
|
|
146
147
|
skip_stream_req = req
|
147
148
|
|
148
149
|
# Incrementally update input logprobs.
|
149
|
-
if
|
150
|
+
if batch.return_logprob:
|
150
151
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
151
152
|
extend_input_len = extend_input_len_per_req[i]
|
152
153
|
if extend_logprob_start_len < extend_input_len:
|
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
|
|
154
155
|
num_input_logprobs = (
|
155
156
|
extend_input_len - extend_logprob_start_len
|
156
157
|
)
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
158
|
+
if req.return_logprob:
|
159
|
+
self.add_input_logprob_return_values(
|
160
|
+
i,
|
161
|
+
req,
|
162
|
+
logits_output,
|
163
|
+
logprob_pt,
|
164
|
+
num_input_logprobs,
|
165
|
+
last_prefill_chunk=False,
|
166
|
+
)
|
165
167
|
logprob_pt += num_input_logprobs
|
166
168
|
|
167
169
|
self.set_next_batch_sampling_info_done(batch)
|
@@ -698,6 +700,8 @@ class SchedulerOutputProcessorMixin:
|
|
698
700
|
output_token_ids_logprobs_val,
|
699
701
|
output_token_ids_logprobs_idx,
|
700
702
|
output_hidden_states,
|
703
|
+
placeholder_tokens_idx=None,
|
704
|
+
placeholder_tokens_val=None,
|
701
705
|
)
|
702
706
|
)
|
703
707
|
|
@@ -717,6 +721,12 @@ class SchedulerOutputProcessorMixin:
|
|
717
721
|
cached_tokens.append(req.cached_tokens)
|
718
722
|
self.send_to_detokenizer.send_pyobj(
|
719
723
|
BatchEmbeddingOut(
|
720
|
-
rids,
|
724
|
+
rids,
|
725
|
+
finished_reasons,
|
726
|
+
embeddings,
|
727
|
+
prompt_tokens,
|
728
|
+
cached_tokens,
|
729
|
+
placeholder_tokens_idx=None,
|
730
|
+
placeholder_tokens_val=None,
|
721
731
|
)
|
722
732
|
)
|
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
|
|
26
26
|
|
27
27
|
class SchedulerProfilerMixin:
|
28
28
|
|
29
|
-
def
|
29
|
+
def init_profiler(self):
|
30
30
|
self.torch_profiler = None
|
31
31
|
self.torch_profiler_output_dir: Optional[str] = None
|
32
32
|
self.profiler_activities: Optional[List[str]] = None
|
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
|
|
121
121
|
url = params["url"]
|
122
122
|
|
123
123
|
worker = self.tp_worker.worker
|
124
|
-
|
125
124
|
worker.model_runner.save_remote_model(url)
|
126
125
|
|
126
|
+
if self.draft_worker is not None:
|
127
|
+
draft_url = params.get("draft_url", None)
|
128
|
+
assert (
|
129
|
+
draft_url is not None
|
130
|
+
), "draft_url must be provided when draft model is enabled"
|
131
|
+
draft_worker = self.draft_worker.worker
|
132
|
+
draft_worker.model_runner.save_remote_model(draft_url)
|
133
|
+
|
127
134
|
def save_sharded_model(self, params):
|
128
135
|
worker = self.tp_worker.worker
|
129
136
|
|
@@ -24,20 +24,20 @@ import os
|
|
24
24
|
import re
|
25
25
|
from typing import Optional
|
26
26
|
|
27
|
-
from sglang.srt.code_completion_parser import (
|
27
|
+
from sglang.srt.parser.code_completion_parser import (
|
28
28
|
CompletionTemplate,
|
29
29
|
FimPosition,
|
30
30
|
completion_template_exists,
|
31
31
|
register_completion_template,
|
32
32
|
)
|
33
|
-
from sglang.srt.conversation import (
|
33
|
+
from sglang.srt.parser.conversation import (
|
34
34
|
Conversation,
|
35
35
|
SeparatorStyle,
|
36
36
|
chat_template_exists,
|
37
37
|
get_conv_template_by_model_path,
|
38
38
|
register_conv_template,
|
39
39
|
)
|
40
|
-
from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
|
40
|
+
from sglang.srt.parser.jinja_template_utils import detect_jinja_template_content_format
|
41
41
|
|
42
42
|
logger = logging.getLogger(__name__)
|
43
43
|
|