sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
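A wheel is a zip archive, so a file-level listing like the one above can be reproduced locally. Below is a minimal sketch, assuming the two wheels have already been downloaded (e.g. via `pip download sglang==0.5.3rc0` and `pip download sglang==0.5.3rc2`) into the working directory; the filenames follow the standard wheel naming and are otherwise an assumption here.

import zipfile

def wheel_files(path: str) -> dict[str, int]:
    """Map each file in a wheel to its uncompressed size."""
    with zipfile.ZipFile(path) as zf:
        return {info.filename: info.file_size for info in zf.infolist()}

old = wheel_files("sglang-0.5.3rc0-py3-none-any.whl")
new = wheel_files("sglang-0.5.3rc2-py3-none-any.whl")

for name in sorted(old.keys() | new.keys()):
    if name not in old:
        print(f"added    {name}")
    elif name not in new:
        print(f"removed  {name}")
    elif old[name] != new[name]:
        print(f"changed  {name}")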
sglang/srt/managers/scheduler_input_blocker.py

@@ -17,7 +17,7 @@ from enum import Enum, auto
 from typing import Any, List, Optional

 from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
-from sglang.srt.poll_based_barrier import PollBasedBarrier
+from sglang.srt.utils.poll_based_barrier import PollBasedBarrier

 logger = logging.getLogger(__name__)

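This import move (also visible in the file list, where poll_based_barrier.py now lives under sglang/srt/utils/) breaks downstream code that imports the old path. A hypothetical compatibility shim for code that must run against both versions:

# Try the new location first (0.5.3rc2+), then the old one (0.5.3rc0 and earlier).
try:
    from sglang.srt.utils.poll_based_barrier import PollBasedBarrier
except ImportError:
    from sglang.srt.poll_based_barrier import PollBasedBarrier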
sglang/srt/managers/scheduler_metrics_mixin.py

@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
-from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

@@ -47,8 +46,11 @@ class SchedulerMetricsMixin:
         self.spec_num_total_forward_ct = 0
         self.cum_spec_accept_length = 0
         self.cum_spec_accept_count = 0
-        self.
+        self.kv_transfer_speed_gb_s: float = 0.0
+        self.kv_transfer_latency_ms: float = 0.0
+
         self.stats = SchedulerStats()
+
         if self.enable_metrics:
             engine_type = "unified"
             labels = {
@@ -61,33 +63,30 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
-        self.balance_meta = dp_balance_meta
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )

+    def udpate_spec_metrics(self, bs: int, num_accepted_tokens: int):
+        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_total_forward_ct += bs
+        self.num_generated_tokens += num_accepted_tokens
+
     def log_prefill_stats(
         self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
+        running_bs_offline_batch: int,
     ):
         gap_latency = time.perf_counter() - self.last_prefill_stats_tic
         self.last_prefill_stats_tic = time.perf_counter()
         self.last_input_throughput = self.last_prefill_tokens / gap_latency
         self.last_prefill_tokens = adder.log_input_tokens

+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
@@ -101,51 +100,53 @@ class SchedulerMetricsMixin:
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-
+            token_usage_msg = (
                 f"full token usage: {full_token_usage:.2f}, "
                 f"swa token usage: {swa_token_usage:.2f}, "
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-
+            token_usage_msg = f"token usage: {token_usage:.2f}, "

-        num_new_seq = len(can_run_list)
         f = (
             f"Prefill batch. "
-            f"#new-seq: {
+            f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
-            f"{
+            f"{token_usage_msg}"
+            f"#running-req: {running_bs}, "
+            f"#queue-req: {len(self.waiting_queue)}, "
         )

         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            f += f"#
-            f += f"#
-            f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"input throughput (token/s): {self.last_input_throughput:.2f}, "
-        else:
-            f += f"#running-req: {running_bs}, "
-            f += f"#queue-req: {len(self.waiting_queue)}, "
+            f += f"#prealloc-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+            f += f"#inflight-req: {len(self.disagg_prefill_inflight_queue)}, "

         logger.info(f)

         if self.enable_metrics:
+            # Basics
             total_tokens = adder.log_input_tokens + adder.log_hit_tokens
-
             cache_hit_rate = (
                 adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
+
             self.stats.num_running_reqs = running_bs
+            self.stats.num_running_reqs_offline_batch = running_bs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage =
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.num_queue_reqs = len(self.waiting_queue)
+            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
             self.stats.cache_hit_rate = cache_hit_rate

-
-
-
-            self.
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0

+            # PD disaggregation
             if self.disaggregation_mode == DisaggregationMode.PREFILL:
                 self.stats.num_prefill_prealloc_queue_reqs = len(
                     self.disagg_prefill_bootstrap_queue.queue
@@ -153,7 +154,18 @@ class SchedulerMetricsMixin:
                 self.stats.num_prefill_inflight_queue_reqs = len(
                     self.disagg_prefill_inflight_queue
                 )
+                self.stats.kv_transfer_speed_gb_s = self.kv_transfer_speed_gb_s
+                self.stats.kv_transfer_latency_ms = self.kv_transfer_latency_ms
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )

+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()
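kv_transfer_speed_gb_s and kv_transfer_latency_ms are plain floats filled in elsewhere by the KV transfer path. A hypothetical helper showing how such figures are conventionally derived from one measured transfer (this function and its inputs are assumptions, not sglang's API):

def kv_transfer_stats(num_bytes: int, elapsed_s: float) -> tuple[float, float]:
    """Derive (speed in GiB/s, latency in ms) from one transfer measurement."""
    speed_gb_s = num_bytes / (1 << 30) / elapsed_s
    latency_ms = elapsed_s * 1000.0
    return speed_gb_s, latency_ms

speed, latency = kv_transfer_stats(num_bytes=2 << 30, elapsed_s=0.5)
print(f"{speed:.2f} GB/s, {latency:.1f} ms")  # 4.00 GB/s, 500.0 ms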
@@ -166,8 +178,12 @@ class SchedulerMetricsMixin:
         gap_latency = time.perf_counter() - self.last_decode_stats_tic
         self.last_decode_stats_tic = time.perf_counter()
         self.last_gen_throughput = self.num_generated_tokens / gap_latency
+
         self.num_generated_tokens = 0
         num_running_reqs = len(batch.reqs)
+        num_running_reqs_offline_batch = 0
+
+        # TODO: generalize this for various memory pools
         if self.is_hybrid:
             (
                 full_num_used,
@@ -181,7 +197,7 @@ class SchedulerMetricsMixin:
             ) = self._get_swa_token_info()
             num_used = max(full_num_used, swa_num_used)
             token_usage = max(full_token_usage, swa_token_usage)
-
+            token_usage_msg = (
                 f"#full token: {full_num_used}, "
                 f"full token usage: {full_token_usage:.2f}, "
                 f"#swa token: {swa_num_used}, "
@@ -189,14 +205,14 @@ class SchedulerMetricsMixin:
             )
         else:
             num_used, token_usage, _, _ = self._get_token_info()
-
+            token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "

         if RECORD_STEP_TIME:
             self.step_time_dict[num_running_reqs].append(
                 gap_latency / self.server_args.decode_log_interval
             )

-        msg = f"Decode batch. #running-req: {num_running_reqs}, {
+        msg = f"Decode batch. #running-req: {num_running_reqs}, {token_usage_msg}"

         if self.spec_algorithm.is_none():
             spec_accept_length = 0
@@ -208,41 +224,66 @@ class SchedulerMetricsMixin:
             self.cum_spec_accept_count += self.spec_num_total_forward_ct
             self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
             msg += f"accept len: {spec_accept_length:.2f}, "
+        cache_hit_rate = 0.0

         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.disagg_decode_prealloc_queue.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#prealloc-req: {len(self.disagg_decode_prealloc_queue.queue)}, "
+            msg += f"#transfer-req: {len(self.disagg_decode_transfer_queue.queue)}, "
             msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

         msg += (
-            f"{'
+            f"{'cuda graph' if self.device == 'cuda' else 'cpu graph'}: {can_run_cuda_graph}, "
             f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
             f"#queue-req: {len(self.waiting_queue)}, "
         )

         logger.info(msg)
         if self.enable_metrics:
+            # Basics
             self.stats.num_running_reqs = num_running_reqs
+            self.stats.num_running_reqs_offline_batch = num_running_reqs_offline_batch
             self.stats.num_used_tokens = num_used
-            self.stats.token_usage =
-            self.
+            self.stats.token_usage = token_usage
+            if self.is_hybrid:
+                self.stats.swa_token_usage = swa_token_usage
             self.stats.gen_throughput = self.last_gen_throughput
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            self.stats.cache_hit_rate = cache_hit_rate
             self.stats.spec_accept_length = spec_accept_length
-
-
-
+
+            # Retract
+            self.stats.num_retracted_reqs = self.num_retracted_reqs
+            self.stats.num_paused_reqs = self.num_paused_reqs
+            self.num_retracted_reqs = self.num_paused_reqs = 0
+
+            # PD disaggregation
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            elif self.disaggregation_mode == DisaggregationMode.DECODE:
                 self.stats.num_decode_prealloc_queue_reqs = len(
                     self.disagg_decode_prealloc_queue.queue
                 )
                 self.stats.num_decode_transfer_queue_reqs = len(
                     self.disagg_decode_transfer_queue.queue
                 )
+
+            # Others
+            self.calculate_utilization()
             self.metrics_collector.log_stats(self.stats)
             self._emit_kv_metrics()
             self._publish_kv_events()

     def _emit_kv_metrics(self: Scheduler):
+        if not self.enable_kv_cache_events:
+            return
+
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -259,93 +300,24 @@ class SchedulerMetricsMixin:
         self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

     def _publish_kv_events(self: Scheduler):
-        if self.enable_kv_cache_events:
-
-            if events:
-                batch = KVEventBatch(ts=time.time(), events=events)
-                self.kv_event_publisher.publish(batch)
+        if not self.enable_kv_cache_events:
+            return

-
-
-
-
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
-
-    def maybe_handle_dp_balance_data(self: Scheduler):
-        if (
-            self.server_args.load_balance_method == "minimum_tokens"
-            and self.forward_ct % 40 == 0
-        ):
-            holding_tokens = self.get_load().num_tokens
-
-            new_recv_dp_balance_id_list, holding_token_list = (
-                self.gather_dp_balance_info(holding_tokens)
-            )
+        events = self.tree_cache.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)

-
-
-
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
-    def gather_dp_balance_info(
-        self: Scheduler, holding_tokens_list
-    ) -> Union[None, List[List[int]]]:
-        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-        recv_list = self.recv_dp_balance_id_this_term
-        assert len(recv_list) <= 511, (
-            "The number of requests received this round is too large. "
-            "Please increase gather_tensor_size and onfly_info_size."
-        )
-        # The maximum size of the tensor used for gathering data from all workers.
-        gather_tensor_size = 512
-
-        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-        recv_tensor[0] = holding_tokens_list
-        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
-        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
-
-        if self.tp_rank == 0:
-            gathered_list = [
-                torch.zeros(gather_tensor_size, dtype=torch.int32)
-                for _ in range(self.balance_meta.num_workers)
-            ]
+    def calculate_utilization(self):
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.utilization = -1
         else:
-
-
-
-
-
-
-
-
-
-            holding_tokens_list.append(tensor[0].item())
-            list_length = tensor[1].item()
-            gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
-
-        return gathered_id_list_per_worker, holding_tokens_list
-
-    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
-        meta = self.balance_meta
-
-        with meta.mutex:
-            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
-            # 1. Check if the rid received by each worker this round is present in onfly.
-            # If it is, remove the corresponding onfly item.
-            worker_id = 0
-            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                for new_recv_rid in new_recv_rids:
-                    assert (
-                        new_recv_rid in on_fly_reqs
-                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                    del on_fly_reqs[new_recv_rid]
-                worker_id += 1
-            # 2. Atomically write local_tokens and onfly into shm under the mutex
-            meta.set_shared_onfly_info(onfly_list)
-            meta.set_shared_local_tokens(local_tokens)
+            if (
+                self.stats.max_running_requests_under_SLO is not None
+                and self.stats.max_running_requests_under_SLO > 0
+            ):
+                self.stats.utilization = max(
+                    self.stats.num_running_reqs
+                    / self.stats.max_running_requests_under_SLO,
+                    self.stats.token_usage / 0.9,
+                )
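The new calculate_utilization above reports -1 on prefill-disaggregation nodes and otherwise takes the max of request-slot pressure and token usage scaled by a 0.9 headroom factor. A compressed standalone sketch of the formula (the real method leaves utilization unset when no SLO capacity is configured; this illustrative function folds both guards into one):

def utilization(num_running_reqs: int,
                max_running_under_slo: int | None,
                token_usage: float) -> float:
    if not max_running_under_slo:
        return -1.0  # no SLO-derived capacity known
    return max(num_running_reqs / max_running_under_slo, token_usage / 0.9)

print(utilization(48, 64, 0.81))  # max(0.75, 0.90) -> 0.9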
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -9,7 +9,11 @@ import torch

 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    BatchEmbeddingOutput,
+    BatchTokenIDOutput,
+)
 from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch

 if TYPE_CHECKING:
@@ -91,7 +95,7 @@ class SchedulerOutputProcessorMixin:

                     if req.finished():
                         self.tree_cache.cache_finished_req(req)
-                        req.time_stats.completion_time = time.
+                        req.time_stats.completion_time = time.perf_counter()
                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                         # This updates radix so others can match
                         self.tree_cache.cache_unfinished_req(req)
@@ -140,7 +144,7 @@ class SchedulerOutputProcessorMixin:
                         logger.error(
                             f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                         )
-                        self.abort_request(AbortReq(req.rid))
+                        self.abort_request(AbortReq(rid=req.rid))
                     req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
@@ -173,8 +177,7 @@ class SchedulerOutputProcessorMixin:
             self.set_next_batch_sampling_info_done(batch)

         else:  # embedding or reward model
-            embeddings
-            embeddings = embeddings.tolist()
+            embeddings = result.embeddings.tolist()

             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -250,8 +253,14 @@ class SchedulerOutputProcessorMixin:

             req.check_finished()
             if req.finished():
-                self.
-
+                if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                    # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
+                    if not self.decode_offload_manager.offload_kv_cache(req):
+                        self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_finished_req(req)
+
+                req.time_stats.completion_time = time.perf_counter()

             if req.return_logprob and batch.spec_algorithm.is_none():
                 # speculative worker handles logprob in speculative decoding
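The offload branch above defers caching when an asynchronous Device->Host copy can be scheduled and falls back to synchronous caching otherwise. A toy sketch of that control flow with stand-in objects (none of these classes are sglang's real ones):

class _StubCache:
    def cache_finished_req(self, req):
        print(f"cached {req} synchronously")

class _StubOffloader:
    def offload_kv_cache(self, req) -> bool:
        # Returns True only if an async offload was scheduled; caching is then
        # deferred to the transfer-completion callback.
        return False  # pretend the async offload could not be scheduled

def on_request_finished(req, offload_manager, tree_cache, offload_enabled: bool):
    if offload_enabled and offload_manager.offload_kv_cache(req):
        return  # cache_finished_req runs later, after the copy completes
    tree_cache.cache_finished_req(req)  # synchronous fallback

on_request_finished("req-1", _StubOffloader(), _StubCache(), offload_enabled=True)
# -> cached req-1 synchronously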
@@ -287,7 +296,7 @@ class SchedulerOutputProcessorMixin:
                     logger.error(
                         f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
                     )
-                    self.abort_request(AbortReq(req.rid))
+                    self.abort_request(AbortReq(rid=req.rid))
                 req.grammar.finished = req.finished()

         self.set_next_batch_sampling_info_done(batch)
@@ -709,8 +718,7 @@ class SchedulerOutputProcessorMixin:
             return

         self.send_to_detokenizer.send_pyobj(
-
-                rids,
+            BatchTokenIDOutput(
                 finished_reasons,
                 decoded_texts,
                 decode_ids_list,
@@ -736,6 +744,7 @@ class SchedulerOutputProcessorMixin:
                 output_token_ids_logprobs_val,
                 output_token_ids_logprobs_idx,
                 output_hidden_states,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
@@ -756,12 +765,12 @@ class SchedulerOutputProcessorMixin:
             prompt_tokens.append(len(req.origin_input_ids))
             cached_tokens.append(req.cached_tokens)
         self.send_to_detokenizer.send_pyobj(
-
-                rids,
+            BatchEmbeddingOutput(
                 finished_reasons,
                 embeddings,
                 prompt_tokens,
                 cached_tokens,
+                rids=rids,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
             )
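The AbortReq(rid=...) and rids=rids changes above move call sites from positional to keyword arguments. The toy dataclass below illustrates why: once fields are added or reordered, positional call sites silently bind values to the wrong field, while keyword call sites keep working (the fields here are illustrative, not sglang's exact AbortReq definition):

from dataclasses import dataclass

@dataclass
class AbortReq:
    abort_all: bool = False
    rid: str = ""

# Positional AbortReq("req-1") would now land "req-1" in abort_all.
# The keyword form is robust to field order:
req = AbortReq(rid="req-1")
print(req)  # AbortReq(abort_all=False, rid='req-1')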
sglang/srt/managers/scheduler_profiler_mixin.py

@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
     def start_profile(
         self, stage: Optional[ForwardMode] = None
     ) -> ProfileReqOutput | None:
-        stage_str = f" for {stage.
+        stage_str = f" for {stage.name}" if stage else ""
         logger.info(
             f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
         )
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
         if not Path(self.torch_profiler_output_dir).exists():
             Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)

-        stage_suffix = f"-{stage.
+        stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
@@ -204,7 +204,7 @@ class SchedulerProfilerMixin:

         torch.distributed.barrier(self.tp_cpu_group)
         if self.tp_rank == 0:
-            from sglang.srt.utils import rpd_to_chrome_trace
+            from sglang.srt.utils.rpd_utils import rpd_to_chrome_trace

             rpd_to_chrome_trace("trace.rpd", self.rpd_profile_path)
             self.rpd_profiler = None
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
         if self.profiler_decode_ct == 0:
             if self.profile_in_progress:
                 # force trace flush
-                self.stop_profile(ForwardMode.EXTEND)
+                self.stop_profile(stage=ForwardMode.EXTEND)
             self.start_profile(batch.forward_mode)
         self.profiler_decode_ct += 1
         if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
             recv_req.profile_by_stage,
             recv_req.profile_id,
         )
-        return self.start_profile(
+        return self.start_profile()
         else:
             return self.stop_profile()
sglang/srt/managers/scheduler_update_weights_mixin.py

@@ -5,6 +5,8 @@ import torch

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
+    DestroyWeightsUpdateGroupReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
     InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
         success, message = self.tp_worker.init_weights_update_group(recv_req)
         return InitWeightsUpdateGroupReqOutput(success, message)

+    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
+        """Destroy the online model parameter update group."""
+        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
+        return DestroyWeightsUpdateGroupReqOutput(success, message)
+
     def update_weights_from_distributed(
         self,
         recv_req: UpdateWeightsFromDistributedReqInput,