sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
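Note that several entries above are renames rather than additions: the OpenAI-compatible layer moves from sglang/srt/openai_api/ to sglang/srt/entrypoints/openai/, and the old monolithic sglang/srt/openai_api/adapter.py is removed in favor of the new serving_* modules. Downstream code importing from the old package path will need an update along the lines of the hedged sketch below, which assumes the protocol classes keep their names across the move.

# Hypothetical compatibility shim for the openai_api -> entrypoints/openai move.
# Assumption: ChatCompletionRequest keeps its name in the relocated protocol module.
try:
    # sglang >= 0.4.8.post1
    from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
except ImportError:
    # sglang <= 0.4.7.post1
    from sglang.srt.openai_api.protocol import ChatCompletionRequest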
sglang/srt/managers/scheduler.py
CHANGED
@@ -23,7 +23,6 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
-from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
@@ -36,6 +35,7 @@ from torch.distributed import barrier

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.constrained.base_grammar_backend import (
     INVALID_GRAMMAR_OBJ,
     create_grammar_backend,
@@ -140,6 +140,7 @@ from sglang.srt.utils import (
     DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
+    configure_gc_logger,
     configure_logger,
     disable_request_logging,
     get_available_gpu_memory,
@@ -148,6 +149,8 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
@@ -179,6 +182,18 @@ class EmbeddingBatchResult:
     bid: int


+class KvMetrics:
+    def __init__(self):
+        self.request_active_slots = None
+        self.request_total_slots = None
+        self.kv_active_blocks = None
+        self.kv_total_blocks = None
+        self.num_requests_waiting = None
+        self.gpu_cache_usage_perc = None
+        self.gpu_prefix_cache_hit_rate = None
+        self.data_parallel_rank = None
+
+
 class IdleSleeper:
     """
     In setups which have long inactivity periods it is desirable to reduce
@@ -220,6 +235,7 @@ class Scheduler(
         self.server_args = server_args
         self.tp_rank = tp_rank
         self.pp_rank = pp_rank
+        self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
@@ -258,6 +274,9 @@ class Scheduler(
             self.send_to_tokenizer = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_ipc_name, False
             )
+            self.send_metrics_from_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.metrics_ipc_name, False
+            )

             if server_args.skip_tokenizer_init:
                 # Directly send to the TokenizerManager
@@ -283,6 +302,7 @@ class Scheduler(
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
+            self.send_metrics_from_scheduler = None
             self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

@@ -450,8 +470,6 @@ class Scheduler(
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
         self.parent_process = psutil.Process().parent()
-
-        # Init memory saver
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
@@ -508,6 +526,9 @@ class Scheduler(
         )
         self.init_disaggregation()

+        if get_bool_env_var("SGLANG_GC_LOG"):
+            configure_gc_logger()
+
     def maybe_sleep_on_idle(self):
         if self.idle_sleeper is not None:
             self.idle_sleeper.maybe_sleep()
@@ -559,12 +580,20 @@ class Scheduler(
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    tp_cache_group=
+                    tp_cache_group=(
+                        self.attn_tp_cpu_group
+                        if self.server_args.enable_dp_attention
+                        else self.tp_cpu_group
+                    ),
                    page_size=self.page_size,
                    hicache_ratio=server_args.hicache_ratio,
                    hicache_size=server_args.hicache_size,
                    hicache_write_policy=server_args.hicache_write_policy,
                )
+                self.tp_worker.register_hicache_layer_transfer_counter(
+                    self.tree_cache.cache_controller.layer_done_counter
+                )
+
            else:
                self.tree_cache = RadixCache(
                    req_to_token_pool=self.req_to_token_pool,
@@ -622,7 +651,12 @@ class Scheduler(
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
-            self.disagg_metadata_buffers = MetadataBuffers(
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )

            # The decode requests polling kv cache
            self.disagg_decode_transfer_queue = DecodeTransferQueue(
@@ -669,7 +703,12 @@ class Scheduler(
            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                buffer_size
            )
-            self.disagg_metadata_buffers = MetadataBuffers(
+            self.disagg_metadata_buffers = MetadataBuffers(
+                buffer_size,
+                hidden_size=self.model_config.hf_text_config.hidden_size,
+                dtype=self.model_config.dtype,
+                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
+            )

            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -795,11 +834,28 @@ class Scheduler(
                        result.next_token_ids,
                        result.bid,
                    )
-
-
-
-
-
+                    if self.cur_batch.return_logprob:
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                                "extend_input_len_per_req": result.extend_input_len_per_req,
+                                "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                            }
+                            | (
+                                {
+                                    f"logits_output.{k}": v
+                                    for k, v in result.logits_output.__dict__.items()
+                                }
+                                if result.logits_output is not None
+                                else {}
+                            )
+                        )
+                    else:
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
                # send the output from the last round to let the next stage worker run post processing
                self.pp_group.send_tensor_dict(
                    pp_outputs.tensors,
@@ -816,12 +872,25 @@ class Scheduler(
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                logits_output_args = {
+                    k[len("logits_output.") :]: v
+                    for k, v in next_pp_outputs.tensors.items()
+                    if k.startswith("logits_output.")
+                }
+                if len(logits_output_args) > 0:
+                    logits_output = LogitsProcessorOutput(**logits_output_args)
+                else:
+                    logits_output = None
                output_result = GenerationBatchResult(
-                    logits_output=
+                    logits_output=logits_output,
                    pp_hidden_states_proxy_tensors=None,
                    next_token_ids=next_pp_outputs["next_token_ids"],
-                    extend_input_len_per_req=
-
+                    extend_input_len_per_req=next_pp_outputs.tensors.get(
+                        "extend_input_len_per_req", None
+                    ),
+                    extend_logprob_start_len_per_req=next_pp_outputs.tensors.get(
+                        "extend_logprob_start_len_per_req", None
+                    ),
                    bid=bids[next_mb_id],
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
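The two hunks above carry logprob-related outputs across pipeline-parallel stages by flattening LogitsProcessorOutput into the PPProxyTensors payload under a "logits_output." key prefix on the sending stage and rebuilding the object on the receiving stage. The following self-contained sketch illustrates that prefix-based round trip with a stand-in dataclass rather than sglang's real LogitsProcessorOutput:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeLogitsOutput:  # stand-in for LogitsProcessorOutput
    next_token_logprobs: Optional[list] = None
    hidden_states: Optional[list] = None


def pack(next_token_ids, logits_output: Optional[FakeLogitsOutput]) -> dict:
    # Flatten the object's fields under a "logits_output." prefix, as the sender hunk does.
    payload = {"next_token_ids": next_token_ids}
    if logits_output is not None:
        payload |= {f"logits_output.{k}": v for k, v in logits_output.__dict__.items()}
    return payload


def unpack(payload: dict) -> Optional[FakeLogitsOutput]:
    # Strip the prefix and rebuild the object, as the receiver hunk does.
    kwargs = {
        k[len("logits_output.") :]: v
        for k, v in payload.items()
        if k.startswith("logits_output.")
    }
    return FakeLogitsOutput(**kwargs) if kwargs else None


sent = pack([42], FakeLogitsOutput(next_token_logprobs=[-0.1]))
assert unpack(sent) == FakeLogitsOutput(next_token_logprobs=[-0.1])
assert unpack(pack([42], None)) is None

The real code additionally keys extend_input_len_per_req and extend_logprob_start_len_per_req directly into the same payload and recovers them with dict.get on the receiving side.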
@@ -1187,6 +1256,22 @@ class Scheduler(
            req.logprob_start_len = len(req.origin_input_ids) - 1
        self._add_request_to_queue(req)

+    def _emit_kv_metrics(self):
+        kv_metrics = KvMetrics()
+        kv_metrics.request_active_slots = self.stats.num_running_reqs
+        kv_metrics.request_total_slots = self.max_running_requests
+        kv_metrics.kv_active_blocks = int(
+            self.stats.token_usage * self.max_total_num_tokens
+        )
+        kv_metrics.kv_total_blocks = self.max_total_num_tokens
+        kv_metrics.num_requests_waiting = self.stats.num_queue_reqs
+        kv_metrics.gpu_cache_usage_perc = self.stats.token_usage
+        kv_metrics.gpu_prefix_cache_hit_rate = self.stats.cache_hit_rate
+        kv_metrics.data_parallel_rank = self.dp_rank if self.dp_rank is not None else 0
+
+        if not self.send_metrics_from_scheduler.closed:
+            self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
+
    def log_prefill_stats(
        self,
        adder: PrefillAdder,
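_emit_kv_metrics fills a KvMetrics object from the scheduler's running stats and pushes it as a pickled Python object over the new metrics ZMQ socket (the zmq.PUSH socket on port_args.metrics_ipc_name created earlier in this diff). Below is a minimal stand-alone consumer sketch using plain pyzmq; the endpoint string is a placeholder, since the real one comes from PortArgs at runtime, and which server component owns the matching end of the socket is not visible in this hunk.

import zmq

# Placeholder endpoint; in sglang it is derived from PortArgs.metrics_ipc_name.
METRICS_IPC_ENDPOINT = "ipc:///tmp/sglang_metrics_example"

context = zmq.Context()
socket = context.socket(zmq.PULL)  # pairs with the scheduler's zmq.PUSH socket
socket.bind(METRICS_IPC_ENDPOINT)

while True:
    # recv_pyobj unpickles the KvMetrics instance sent via send_pyobj;
    # the KvMetrics class must be importable in this process.
    kv_metrics = socket.recv_pyobj()
    print(
        f"dp_rank={kv_metrics.data_parallel_rank} "
        f"active={kv_metrics.request_active_slots}/{kv_metrics.request_total_slots} "
        f"waiting={kv_metrics.num_requests_waiting} "
        f"kv_usage={kv_metrics.gpu_cache_usage_perc:.3f}"
    )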
@@ -1239,6 +1324,7 @@ class Scheduler(
            self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq

            self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
        self._publish_kv_events()

    def log_decode_stats(
@@ -1300,6 +1386,7 @@ class Scheduler(
            self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
            self.stats.spec_accept_length = spec_accept_length
            self.metrics_collector.log_stats(self.stats)
+            self._emit_kv_metrics()
        self._publish_kv_events()

    def check_memory(self):
@@ -1322,7 +1409,14 @@ class Scheduler(
            )
            raise ValueError(msg)

-        if
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            req_total_size = (
+                self.req_to_token_pool.size + self.req_to_token_pool.pre_alloc_size
+            )
+        else:
+            req_total_size = self.req_to_token_pool.size
+
+        if len(self.req_to_token_pool.free_slots) != req_total_size:
            msg = (
                "req_to_token_pool memory leak detected!"
                f"available_size={len(self.req_to_token_pool.free_slots)}, "
@@ -1383,6 +1477,15 @@ class Scheduler(
            self.running_batch.merge_batch(self.last_batch)

        new_batch = self.get_new_batch_prefill()
+
+        need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
        if new_batch is not None:
            # Run prefill first if possible
            ret = new_batch
@@ -1395,8 +1498,8 @@ class Scheduler(
                ret = None

        # Handle DP attention
-        if
-            ret, _ = self.
+        if need_dp_attn_preparation:
+            ret, _ = self.prepare_mlp_sync_batch(ret)

        return ret

@@ -1428,15 +1531,14 @@ class Scheduler(
            return None

        if self.enable_hierarchical_cache:
-
-            self.tree_cache.writing_check()
-            self.tree_cache.loading_check()
+            self.tree_cache.check_hicache_events()

        # Get priority queue
-
+        self.policy.calc_priority(self.waiting_queue)

        # Prefill policy
        adder = PrefillAdder(
+            self.page_size,
            self.tree_cache,
            self.token_to_kv_pool_allocator,
            self.running_batch,
@@ -1478,14 +1580,8 @@ class Scheduler(
                self.running_batch.batch_is_full = True
                break

-            req.init_next_round_input(
-
-                self.enable_hierarchical_cache,
-            )
-
-            res = adder.add_one_req(
-                req, self.chunked_req, self.enable_hierarchical_cache
-            )
+            req.init_next_round_input(self.tree_cache)
+            res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))

            if res != AddReqResult.CONTINUE:
                if res == AddReqResult.NO_TOKEN:
@@ -1512,9 +1608,6 @@ class Scheduler(
            x for x in self.waiting_queue if x not in set(can_run_list)
        ]

-        if self.enable_hierarchical_cache:
-            self.tree_cache.ready_to_load_cache()
-
        if adder.new_chunked_req is not None:
            assert self.chunked_req is None
            self.chunked_req = adder.new_chunked_req
@@ -1538,6 +1631,12 @@ class Scheduler(
            self.server_args.enable_custom_logit_processor,
            chunked_req=self.chunked_req,
        )
+        if self.enable_hierarchical_cache:
+            # todo (zhiqiang): disable cuda graph execution if hicache loading triggered
+            new_batch.hicache_consumer_index = (
+                self.tree_cache.ready_to_load_host_cache()
+            )
+
        new_batch.prepare_for_extend()

        # Mixed-style chunked prefill
@@ -1613,6 +1712,11 @@ class Scheduler(
        if self.is_generation:
            if self.spec_algorithm.is_none():
                model_worker_batch = batch.get_model_worker_batch()
+
+                # update the consumer index of hicache to the running batch
+                self.tp_worker.set_hicache_consumer(
+                    model_worker_batch.hicache_consumer_index
+                )
                if self.pp_group.is_last_rank:
                    logits_output, next_token_ids, can_run_cuda_graph = (
                        self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1641,13 +1745,15 @@ class Scheduler(
            # These 2 values are needed for processing the output, but the values can be
            # modified by overlap schedule. So we have to copy them here so that
            # we can use the correct values in output processing.
-            if batch.return_logprob:
+            if batch.return_logprob or self.spec_algorithm.is_eagle():
                extend_input_len_per_req = [req.extend_input_len for req in batch.reqs]
+            else:
+                extend_input_len_per_req = None
+
+            if batch.return_logprob:
                extend_logprob_start_len_per_req = [
                    req.extend_logprob_start_len for req in batch.reqs
                ]
            else:
-                extend_input_len_per_req = None
                extend_logprob_start_len_per_req = None

            ret = GenerationBatchResult(
@@ -1695,12 +1801,11 @@ class Scheduler(
            self.return_health_check_ct -= 1
            self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def
-        return self.
+    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_mlp_sync_batch_raw(
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1709,14 +1814,14 @@ class Scheduler(
            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
            enable_deepep_moe=self.server_args.enable_deepep_moe,
            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
        )

    @staticmethod
-    def
+    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
-        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
@@ -1725,6 +1830,7 @@ class Scheduler(
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
+        require_mlp_tp_gather: bool,
    ):
        # Check if other DP workers have running batches
        if local_batch is None:
@@ -1732,8 +1838,6 @@ class Scheduler(
            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
@@ -1752,11 +1856,6 @@ class Scheduler(
        else:
            can_cuda_graph = 0

-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )
@@ -1801,7 +1900,7 @@ class Scheduler(

        if local_batch is not None:
            # TODO: handle the case when moe_dense_tp_size != 1
-            if
+            if not require_mlp_tp_gather:
                local_batch.global_num_tokens = [num_tokens]
                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
            else:
@@ -1809,6 +1908,7 @@ class Scheduler(
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
            local_batch.tbo_split_seq_index = tbo_split_seq_index
            local_batch.global_forward_mode = global_forward_mode

@@ -1816,6 +1916,7 @@ class Scheduler(
        if not disable_cuda_graph:
            local_batch.can_run_dp_cuda_graph = can_cuda_graph

+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
@@ -2135,8 +2236,8 @@ class Scheduler(
        """In-place update of the weights from disk."""
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        if success:
-
-            assert
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightFromDiskReqOutput(success, message, 0)
@@ -2153,8 +2254,8 @@ class Scheduler(
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
-
-            assert
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return UpdateWeightsFromDistributedReqOutput(success, message)
@@ -2165,10 +2266,11 @@ class Scheduler(
        # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
        if success:
            if recv_req.flush_cache:
-
-                assert
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
+        barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
@@ -2176,23 +2278,40 @@ class Scheduler(
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
-
-
-
-
-
-
-
-
+        tags = recv_req.tags
+        import subprocess
+
+        if tags is None:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
+            self.flush_cache()
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.stashed_model_static_state = _export_static_state(
+                self.tp_worker.worker.model_runner.model
+            )
+            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
+
        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
-
-
-
-
-
-
+        tags = recv_req.tags
+        if tags is None or len(tags) == 0:
+            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
+
+        if GPU_MEMORY_TYPE_WEIGHTS in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
+            _import_static_state(
+                self.tp_worker.worker.model_runner.model,
+                self.stashed_model_static_state,
+            )
+            del self.stashed_model_static_state
+
+        if GPU_MEMORY_TYPE_KV_CACHE in tags:
+            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)
+
        return ResumeMemoryOccupationReqOutput()

    def slow_down(self, recv_req: SlowDownReqInput):
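The rewritten release_memory_occupation / resume_memory_occupation handlers above act on per-type tags (GPU_MEMORY_TYPE_WEIGHTS and GPU_MEMORY_TYPE_KV_CACHE from the new sglang/srt/constants.py) instead of always pausing everything: the KV cache is paused together with a cache flush, while model weights are stashed via _export_static_state before pausing and restored on resume. A simplified, standalone sketch of the same tag-filtering idea follows; the adapter class and helper functions here are stand-ins for illustration, not sglang's real ones.

GPU_MEMORY_TYPE_WEIGHTS = "weights"
GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"


class ToyMemorySaver:
    """Stand-in for TorchMemorySaverAdapter: just records which regions are paused."""

    def __init__(self):
        self.paused = set()

    def pause(self, tag):
        self.paused.add(tag)

    def resume(self, tag):
        self.paused.discard(tag)


def release(saver, tags=None):
    # No tags means "release everything", mirroring the handler above.
    tags = tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
    if GPU_MEMORY_TYPE_KV_CACHE in tags:
        saver.pause(GPU_MEMORY_TYPE_KV_CACHE)   # the real handler also flushes the radix cache
    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        saver.pause(GPU_MEMORY_TYPE_WEIGHTS)    # the real handler stashes static model state first


def resume(saver, tags=None):
    tags = tags or [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]
    if GPU_MEMORY_TYPE_WEIGHTS in tags:
        saver.resume(GPU_MEMORY_TYPE_WEIGHTS)   # weights are brought back before the KV cache
    if GPU_MEMORY_TYPE_KV_CACHE in tags:
        saver.resume(GPU_MEMORY_TYPE_KV_CACHE)


saver = ToyMemorySaver()
release(saver, tags=[GPU_MEMORY_TYPE_KV_CACHE])   # e.g. free only the KV cache between rollouts
assert saver.paused == {GPU_MEMORY_TYPE_KV_CACHE}
resume(saver)                                     # no tags: resume everything
assert saver.paused == set()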
@@ -2421,8 +2540,10 @@ class Scheduler(
                if self.profiler_decode_ct > self.profiler_target_decode_ct:
                    if self.profile_in_progress:
                        self.stop_profile(stage=ForwardMode.DECODE)
+            elif batch.forward_mode.is_idle():
+                pass
            else:
-                raise RuntimeError("unsupported profile stage")
+                raise RuntimeError(f"unsupported profile stage: {batch.forward_mode}")
        else:
            # Check profiler
            if (