sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -38,13 +38,19 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-
-
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -108,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics

-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)

@@ -141,9 +144,12 @@ class Scheduler:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation

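ModelConfig now also receives the revision, dtype, and quantization straight from the server arguments. A usage sketch based only on the keyword arguments visible in this hunk; the literal values below are placeholders for illustration, not sglang defaults:

```python
from sglang.srt.configs.model_config import ModelConfig

# Illustrative only: the keyword arguments the scheduler now forwards.
# The values here are placeholders (any HF model id / JSON override string works).
model_config = ModelConfig(
    "Qwen/Qwen2.5-0.5B-Instruct",  # model_path
    trust_remote_code=False,
    revision=None,              # new: pin a specific Hugging Face revision
    context_length=None,        # None -> use the value from the HF config
    model_override_args="{}",
    is_embedding=False,
    dtype="auto",               # new: forwarded from --dtype
    quantization=None,          # new: forwarded from --quantization
)
print(model_config.is_generation)
```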
@@ -250,9 +256,15 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
+        if self.chunked_prefill_size <= 0:  # -1 means disable
+            self.chunked_prefill_size = None
         self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
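Two small behaviors are added here: the scheduler caches a device-agnostic stream handle via `torch.get_device_module(self.device).current_stream()` instead of assuming CUDA, and a non-positive `--chunked-prefill-size` is normalized to `None` (disabled). A minimal sketch of both, assuming a CUDA-like backend for the stream lookup:

```python
import torch

# torch.get_device_module("cuda") is simply the torch.cuda module, so the same
# line works for other CUDA-like backends that expose current_stream().
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cpu":
    current_stream = torch.get_device_module(device).current_stream()
    print(current_stream)

# Any non-positive value (e.g. --chunked-prefill-size -1) now means "disabled".
chunked_prefill_size = -1
if chunked_prefill_size <= 0:
    chunked_prefill_size = None
assert chunked_prefill_size is None
```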
@@ -345,6 +357,7 @@ class Scheduler:
         )

     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()

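The added docstring states the intent; the mechanism itself is not part of this hunk. A simplified, self-contained sketch of that general pattern, with hypothetical names (`forward_ct`, `timeout_s`) standing in for the scheduler's real fields:

```python
import os
import threading
import time

# Simplified watchdog sketch: a counter the forward loop increments, and a
# thread that kills the process if the counter stalls for too long.
class Watchdog:
    def __init__(self, timeout_s: float = 300.0):
        self.timeout_s = timeout_s
        self.forward_ct = 0          # incremented by the forward loop
        self._last_ct = 0
        self._last_time = time.time()

    def run(self):
        while True:
            if self.forward_ct > self._last_ct:
                # Progress was made; reset the timer.
                self._last_ct = self.forward_ct
                self._last_time = time.time()
            elif time.time() - self._last_time > self.timeout_s:
                # One batch has been stuck for too long: abort the whole process.
                print("watchdog timeout, aborting")
                os._exit(1)
            time.sleep(5)

watchdog = Watchdog(timeout_s=600)
threading.Thread(target=watchdog.run, daemon=True).start()
```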
@@ -422,61 +435,6 @@ class Scheduler:

             self.last_batch = batch

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -504,11 +462,27 @@ class Scheduler:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
-            elif isinstance(recv_req,
-                success, message = self.
+            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
+                success, message = self.update_weights_from_disk(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightFromDiskReqOutput(success, message)
+                )
+            elif isinstance(recv_req, GetWeightsByNameReqInput):
+                parameter = self.get_weights_by_name(recv_req)
+                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
+                success, message = self.init_weights_update_group(recv_req)
                 self.send_to_tokenizer.send_pyobj(
-
+                    InitWeightsUpdateGroupReqOutput(success, message)
                 )
+            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
+                success, message = self.update_weights_from_distributed(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightsFromDistributedReqOutput(success, message)
+                )
+            elif isinstance(recv_req, GetWeightsByNameReqInput):
+                parameter = self.get_weights_by_name(recv_req)
+                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()
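Every new branch follows the same round trip: a typed request object arrives from the tokenizer process, the scheduler calls the matching handler, and a typed result object is sent back over the ZMQ socket. A self-contained sketch of that dispatch pattern, using hypothetical dataclasses and a plain list in place of the real io_struct classes and `send_to_tokenizer`:

```python
from dataclasses import dataclass
from typing import List

# Hypothetical stand-ins for the io_struct request/response classes.
@dataclass
class UpdateWeightsReq:
    model_path: str

@dataclass
class UpdateWeightsResp:
    success: bool
    message: str

@dataclass
class GetWeightsByNameReq:
    name: str

@dataclass
class GetWeightsByNameResp:
    parameter: list

sent_to_tokenizer: List[object] = []  # stands in for send_to_tokenizer.send_pyobj

def process_input_request(recv_req):
    # isinstance-based dispatch, mirroring the scheduler's elif chain.
    if isinstance(recv_req, UpdateWeightsReq):
        success, message = True, f"loaded {recv_req.model_path}"
        sent_to_tokenizer.append(UpdateWeightsResp(success, message))
    elif isinstance(recv_req, GetWeightsByNameReq):
        sent_to_tokenizer.append(GetWeightsByNameResp(parameter=[0.0, 1.0]))

process_input_request(UpdateWeightsReq("new-checkpoint"))
process_input_request(GetWeightsByNameReq("lm_head.weight"))
print(sent_to_tokenizer)
```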
@@ -653,7 +627,7 @@ class Scheduler:

         self.waiting_queue.append(req)

-    def log_prefill_stats(self, adder, can_run_list, running_bs,
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
         if isinstance(self.tree_cache, RadixCache):
             self.tree_cache_metrics["total"] += (
                 adder.log_input_tokens + adder.log_hit_tokens
@@ -677,14 +651,14 @@ class Scheduler:
             f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) +
+            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
         )

         if self.enable_metrics:
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
-            self.stats.num_queue_reqs = len(self.waiting_queue) +
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
             self.stats.cache_hit_rate = tree_cache_hit_rate
             self.metrics_collector.log_stats(self.stats)

@@ -745,7 +719,7 @@ class Scheduler:
                 # Move the chunked request out of the batch
                 self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
                 self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-                #
+                # being chunked request keeps its rid but will get a new req_pool_idx
                 self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False

@@ -796,10 +770,10 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )

-
-        if
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.
+            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

         if self.lora_paths:
             lora_set = (
@@ -841,16 +815,16 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

-        if adder.
+        if adder.new_being_chunked_req is not None:
             assert self.being_chunked_req is None
-            self.being_chunked_req = adder.
+            self.being_chunked_req = adder.new_being_chunked_req

         if self.being_chunked_req:
             self.being_chunked_req.is_being_chunked += 1

         # Print stats
         if self.tp_rank == 0:
-            self.log_prefill_stats(adder, can_run_list, running_bs,
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -966,7 +940,7 @@ class Scheduler:
             self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
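The inserted `self.current_stream.synchronize()` drains the device stream before `sampling_info_done` is set, so the thread waiting on that event never reads a vocab mask whose kernels are still in flight. A simplified, CPU-safe sketch of that handshake (the tensor update stands in for `update_regex_vocab_mask()`):

```python
import threading
import torch

# "Finish device work, then signal" handshake.
mask = torch.zeros(8)
done = threading.Event()

def producer():
    mask.add_(1)  # stands in for the mask update launched on a stream
    if torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()  # drain pending kernels
    done.set()  # only now is it safe for the consumer to read `mask`

def consumer():
    done.wait()
    print(mask.sum().item())  # guaranteed to observe the finished update

t_consumer = threading.Thread(target=consumer)
t_producer = threading.Thread(target=producer)
t_consumer.start()
t_producer.start()
t_producer.join()
t_consumer.join()
```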
@@ -1022,13 +996,14 @@ class Scheduler:

                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
-                    #
+                    # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
@@ -1051,7 +1026,7 @@ class Scheduler:
                     else:
                         self.tree_cache.cache_unfinished_req(req)
                 else:
-                    #
+                    # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

         self.stream_output(batch.reqs)
@@ -1100,10 +1075,11 @@ class Scheduler:

             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()

         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)
@@ -1146,6 +1122,14 @@ class Scheduler:
                 + 1 : len(req.fill_ids)
                 - req.last_update_decode_tokens
             ]
+
+            # Clip the padded hash values from image tokens.
+            # Otherwise, it will lead to detokenization errors.
+            input_token_ids = [
+                x if x < self.model_config.vocab_size - 1 else 0
+                for x in input_token_ids
+            ]
+
             req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

             if (
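The new clamp keeps detokenization from indexing outside the vocabulary: padded image-token hash values land at or above `vocab_size - 1`, so they are replaced with token id 0 before the logprob pairs are built. A tiny illustration with made-up numbers:

```python
# Made-up vocab size and ids; the third id plays the role of a padded image hash.
vocab_size = 32000
input_token_ids = [15, 27, 4093261057, 88]

clipped = [x if x < vocab_size - 1 else 0 for x in input_token_ids]
assert clipped == [15, 27, 0, 88]
```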
@@ -1293,6 +1277,61 @@ class Scheduler:
                 )
             )

+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
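`prepare_dp_attn_batch` (moved here from earlier in the file) keeps data-parallel attention workers in lockstep: each rank gathers every rank's token count over the CPU (gloo) group, and a rank with nothing to run schedules an idle batch whenever any other rank has work, so the collectives inside the forward pass always see all ranks. A single-rank, runnable sketch of the gather step; the process-group setup below exists only for the demo:

```python
import os
import torch
import torch.distributed as dist

# Single-rank gloo group so the collective can run standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

world_size = dist.get_world_size()
num_tokens = 17  # e.g. extend_num_tokens of the local batch (0 if no batch)

local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(world_size, dtype=torch.int64)
dist.all_gather_into_tensor(global_num_tokens, local_num_tokens)

# A worker with no local batch still runs an idle batch if any other worker
# has tokens, so every rank participates in the same collectives.
if num_tokens == 0 and global_num_tokens.max().item() > 0:
    print("would schedule an idle batch")

dist.destroy_process_group()
```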
@@ -1361,9 +1400,9 @@ class Scheduler:
                     req.to_abort = True
                     break

-    def
-        """In-place update of the weights."""
-        success, message = self.tp_worker.
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        """In-place update of the weights from disk."""
+        success, message = self.tp_worker.update_weights_from_disk(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
@@ -1371,6 +1410,27 @@ class Scheduler:
             logger.error(message)
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        """Initialize the online model parameter update group."""
+        success, message = self.tp_worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        """Update the online model parameter."""
+        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return parameter
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
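These handlers only forward the requests to the TP worker; the actual tensor transfer lives in the worker and model runner, which are not part of this file. A minimal sketch of one plausible mechanism behind `init_weights_update_group` plus `update_weights_from_distributed` (a trainer rank broadcasting fresh weights into a shared process group), run here as a single-rank demo; the group layout, source rank, and transport are assumptions, not sglang's actual implementation:

```python
import os
import torch
import torch.distributed as dist

# Single-process demo setup, standing in for trainer + inference ranks.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# init_weights_update_group: trainer and inference ranks agree on a group.
group = dist.new_group(ranks=[0])

# update_weights_from_distributed: the source rank holds the fresh tensor;
# every inference rank receives it in place and can swap it into the model.
param = torch.zeros(4)
dist.broadcast(param, src=0, group=group)

dist.destroy_process_group()
```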
@@ -1413,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1426,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()
     parent_process = psutil.Process().parent()
