sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/managers/scheduler.py (entry 50 in the list above).

@@ -38,13 +38,19 @@ from sglang.srt.managers.io_struct import (
  BatchTokenIDOut,
  CloseSessionReqInput,
  FlushCacheReq,
+ GetWeightsByNameReqInput,
+ GetWeightsByNameReqOutput,
+ InitWeightsUpdateGroupReqInput,
+ InitWeightsUpdateGroupReqOutput,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
- UpdateWeightReqInput,
- UpdateWeightReqOutput,
+ UpdateWeightFromDiskReqInput,
+ UpdateWeightFromDiskReqOutput,
+ UpdateWeightsFromDistributedReqInput,
+ UpdateWeightsFromDistributedReqOutput,
  )
  from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
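The import block now pulls in four request/response pairs instead of one: disk reload (UpdateWeightFromDisk*), process-group setup (InitWeightsUpdateGroup*), in-flight distributed updates (UpdateWeightsFromDistributed*), and parameter inspection (GetWeightsByName*). A rough sketch of how such paired io_struct messages are typically shaped; the field names below are assumptions for illustration, not the real definitions in sglang/srt/managers/io_struct.py:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class UpdateWeightFromDiskReqInput:
    model_path: str                    # assumed field: where to reload weights from
    load_format: Optional[str] = None  # assumed field


@dataclass
class UpdateWeightFromDiskReqOutput:
    success: bool
    message: str


@dataclass
class GetWeightsByNameReqInput:
    name: str                          # parameter name to fetch from the TP worker


@dataclass
class GetWeightsByNameReqOutput:
    parameter: Any                     # whatever the TP worker returns for that name
```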
@@ -108,9 +114,6 @@ class Scheduler:
  self.skip_tokenizer_init = server_args.skip_tokenizer_init
  self.enable_metrics = server_args.enable_metrics

- # Session info
- self.sessions = {}
-
  # Init inter-process communication
  context = zmq.Context(2)

@@ -141,9 +144,12 @@
  self.model_config = ModelConfig(
  server_args.model_path,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ dtype=server_args.dtype,
+ quantization=server_args.quantization,
  )
  self.is_generation = self.model_config.is_generation

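ModelConfig now receives revision, dtype, and quantization directly from the server arguments. A minimal sketch of the new constructor call, assuming only the keyword names shown in the hunk above; the literal values stand in for ServerArgs fields:

```python
from sglang.srt.configs.model_config import ModelConfig

# Sketch only: keyword names come from the hunk above, values are placeholders.
model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # server_args.model_path
    trust_remote_code=False,             # server_args.trust_remote_code
    revision=None,                       # new in 0.4.0: pin a specific HF revision
    context_length=None,                 # None -> use the model's default
    model_override_args="{}",            # server_args.json_model_override_args
    is_embedding=False,                  # server_args.is_embedding
    dtype="auto",                        # new: dtype is resolved inside ModelConfig
    quantization=None,                   # new: e.g. "fp8" (fp8.py is added in this release)
)
```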
@@ -250,9 +256,15 @@
  self.num_generated_tokens = 0
  self.last_decode_stats_tic = time.time()
  self.stream_interval = server_args.stream_interval
+ self.current_stream = torch.get_device_module(self.device).current_stream()
+
+ # Session info
+ self.sessions = {}

  # Init chunked prefill
  self.chunked_prefill_size = server_args.chunked_prefill_size
+ if self.chunked_prefill_size <= 0: # -1 means disable
+ self.chunked_prefill_size = None
  self.being_chunked_req = None
  self.is_mixed_chunk = (
  self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
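The scheduler now caches a device-agnostic stream via torch.get_device_module instead of calling torch.cuda.current_stream() at each synchronization point, which lets the later `self.current_stream.synchronize()` calls work on non-CUDA backends too. A small sketch of both changes in this hunk, assuming a recent PyTorch that provides torch.get_device_module:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# torch.get_device_module("cuda") returns torch.cuda, "xpu" returns torch.xpu,
# etc., so the same line works for any accelerator backend.
if device != "cpu":
    current_stream = torch.get_device_module(device).current_stream()
    # ... launch kernels ...
    current_stream.synchronize()  # replaces torch.cuda.current_stream().synchronize()

# Chunked-prefill normalization from the same hunk: a non-positive value
# (the CLI uses -1) now means "disabled".
chunked_prefill_size = -1
if chunked_prefill_size <= 0:
    chunked_prefill_size = None
```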
@@ -345,6 +357,7 @@
  )

  def watchdog_thread(self):
+ """A watch dog thread that will try to kill the server itself if one batch takes too long."""
  self.watchdog_last_forward_ct = 0
  self.watchdog_last_time = time.time()

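The new docstring spells out the intent of watchdog_thread: if the forward counter stops advancing for longer than a timeout, the scheduler kills its own process rather than letting a hung batch stall the server silently. A simplified, self-contained sketch of that pattern; the counter, timeout, and kill logic here are illustrative, not the exact sglang implementation:

```python
import os
import signal
import threading
import time


class TinyScheduler:
    """Illustrative only: a work loop that advances a counter, plus a watchdog."""

    def __init__(self, watchdog_timeout: float = 5.0):
        self.forward_ct = 0
        self.watchdog_timeout = watchdog_timeout

    def watchdog_thread(self):
        last_ct, last_time = 0, time.time()
        while True:
            if self.forward_ct > last_ct:
                # Progress was made; reset the timer.
                last_ct, last_time = self.forward_ct, time.time()
            elif time.time() - last_time > self.watchdog_timeout:
                # One batch has taken too long: kill the whole process.
                print("Watchdog timeout, killing the server")
                os.kill(os.getpid(), signal.SIGKILL)
            time.sleep(self.watchdog_timeout / 2)

    def start(self):
        # Daemon thread so it never blocks interpreter shutdown.
        threading.Thread(target=self.watchdog_thread, daemon=True).start()
```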
@@ -422,61 +435,6 @@

  self.last_batch = batch

- def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
- # Check if other DP workers have running batches
- if local_batch is None:
- num_tokens = 0
- elif local_batch.forward_mode.is_decode():
- num_tokens = local_batch.batch_size()
- else:
- num_tokens = local_batch.extend_num_tokens
-
- local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
- global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
- torch.distributed.all_gather_into_tensor(
- global_num_tokens,
- local_num_tokens,
- group=self.tp_cpu_group,
- )
-
- if local_batch is None and global_num_tokens.max().item() > 0:
- local_batch = self.get_idle_batch()
-
- if local_batch is not None:
- local_batch.global_num_tokens = global_num_tokens.tolist()
-
- # Check forward mode for cuda graph
- if not self.server_args.disable_cuda_graph:
- forward_mode_state = torch.tensor(
- (
- 1
- if local_batch.forward_mode.is_decode()
- or local_batch.forward_mode.is_idle()
- else 0
- ),
- dtype=torch.int32,
- )
- torch.distributed.all_reduce(
- forward_mode_state,
- op=torch.distributed.ReduceOp.MIN,
- group=self.tp_cpu_group,
- )
- local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
- return local_batch
-
- def get_idle_batch(self):
- idle_batch = ScheduleBatch.init_new(
- [],
- self.req_to_token_pool,
- self.token_to_kv_pool,
- self.tree_cache,
- self.model_config,
- self.enable_overlap,
- )
- idle_batch.prepare_for_idle()
- return idle_batch
-
  def recv_requests(self):
  if self.tp_rank == 0 or self.server_args.enable_dp_attention:
  recv_reqs = []
@@ -504,11 +462,27 @@
  self.flush_cache()
  elif isinstance(recv_req, AbortReq):
  self.abort_request(recv_req)
- elif isinstance(recv_req, UpdateWeightReqInput):
- success, message = self.update_weights(recv_req)
+ elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
+ success, message = self.update_weights_from_disk(recv_req)
+ self.send_to_tokenizer.send_pyobj(
+ UpdateWeightFromDiskReqOutput(success, message)
+ )
+ elif isinstance(recv_req, GetWeightsByNameReqInput):
+ parameter = self.get_weights_by_name(recv_req)
+ self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+ elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
+ success, message = self.init_weights_update_group(recv_req)
  self.send_to_tokenizer.send_pyobj(
- UpdateWeightReqOutput(success, message)
+ InitWeightsUpdateGroupReqOutput(success, message)
  )
+ elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
+ success, message = self.update_weights_from_distributed(recv_req)
+ self.send_to_tokenizer.send_pyobj(
+ UpdateWeightsFromDistributedReqOutput(success, message)
+ )
+ elif isinstance(recv_req, GetWeightsByNameReqInput):
+ parameter = self.get_weights_by_name(recv_req)
+ self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
  elif isinstance(recv_req, ProfileReq):
  if recv_req == ProfileReq.START_PROFILE:
  self.start_profile()
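Each new control request follows the same round trip: the tokenizer-side process sends a pickled request object over ZMQ, recv_requests picks it up, the isinstance chain above dispatches it, and the result goes back with send_to_tokenizer.send_pyobj. A stripped-down sketch of that pattern with hypothetical request types and endpoints (not sglang's real sockets or classes); calling serve() blocks in a receive loop:

```python
from dataclasses import dataclass

import zmq


@dataclass
class PingReq:          # stand-in for e.g. GetWeightsByNameReqInput
    name: str


@dataclass
class PingReqOutput:    # stand-in for the matching ...ReqOutput class
    value: str


def serve(recv_endpoint: str = "tcp://127.0.0.1:5555",
          send_endpoint: str = "tcp://127.0.0.1:5556"):
    ctx = zmq.Context(2)
    recv_sock = ctx.socket(zmq.PULL)   # like recv_from_tokenizer
    recv_sock.bind(recv_endpoint)
    send_sock = ctx.socket(zmq.PUSH)   # like send_to_tokenizer
    send_sock.bind(send_endpoint)

    while True:
        recv_req = recv_sock.recv_pyobj()      # pickled dataclass, as in recv_requests()
        if isinstance(recv_req, PingReq):      # isinstance dispatch, as in the hunk above
            send_sock.send_pyobj(PingReqOutput(value=f"hello {recv_req.name}"))
```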
@@ -653,7 +627,7 @@

  self.waiting_queue.append(req)

- def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
+ def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
  if isinstance(self.tree_cache, RadixCache):
  self.tree_cache_metrics["total"] += (
  adder.log_input_tokens + adder.log_hit_tokens
@@ -677,14 +651,14 @@
  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"#running-req: {running_bs}, "
- f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+ f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
  )

  if self.enable_metrics:
  self.stats.num_running_reqs = running_bs
  self.stats.num_used_tokens = num_used
  self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
- self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
+ self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
  self.stats.cache_hit_rate = tree_cache_hit_rate
  self.metrics_collector.log_stats(self.stats)

@@ -745,7 +719,7 @@
  # Move the chunked request out of the batch
  self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
  self.tree_cache.cache_unfinished_req(self.being_chunked_req)
- # Inflight request keeps its rid but will get a new req_pool_idx
+ # being chunked request keeps its rid but will get a new req_pool_idx
  self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
  self.batch_is_full = False

@@ -796,10 +770,10 @@
  running_bs if self.is_mixed_chunk else 0,
  )

- has_inflight = self.being_chunked_req is not None
- if has_inflight:
+ has_being_chunked = self.being_chunked_req is not None
+ if has_being_chunked:
  self.being_chunked_req.init_next_round_input()
- self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
+ self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)

  if self.lora_paths:
  lora_set = (
@@ -841,16 +815,16 @@
  x for x in self.waiting_queue if x not in set(can_run_list)
  ]

- if adder.new_inflight_req is not None:
+ if adder.new_being_chunked_req is not None:
  assert self.being_chunked_req is None
- self.being_chunked_req = adder.new_inflight_req
+ self.being_chunked_req = adder.new_being_chunked_req

  if self.being_chunked_req:
  self.being_chunked_req.is_being_chunked += 1

  # Print stats
  if self.tp_rank == 0:
- self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
+ self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

  # Create a new batch
  new_batch = ScheduleBatch.init_new(
@@ -966,7 +940,7 @@
  self.process_batch_result_prefill(batch, result)
  elif batch.forward_mode.is_dummy_first():
  batch.next_batch_sampling_info.update_regex_vocab_mask()
- torch.cuda.current_stream().synchronize()
+ self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()

  def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1022,13 +996,14 @@

  if req.grammar is not None:
  req.grammar.accept_token(next_token_id)
+ req.grammar.finished = req.finished()
  else:
- # Inflight reqs' prefill is not finished
+ # being chunked reqs' prefill is not finished
  req.is_being_chunked -= 1

  if batch.next_batch_sampling_info:
  batch.next_batch_sampling_info.update_regex_vocab_mask()
- torch.cuda.current_stream().synchronize()
+ self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()

  else: # embedding or reward model
@@ -1051,7 +1026,7 @@
  else:
  self.tree_cache.cache_unfinished_req(req)
  else:
- # Inflight reqs' prefill is not finished
+ # being chunked reqs' prefill is not finished
  req.is_being_chunked -= 1

  self.stream_output(batch.reqs)
@@ -1100,10 +1075,11 @@

  if req.grammar is not None:
  req.grammar.accept_token(next_token_id)
+ req.grammar.finished = req.finished()

  if batch.next_batch_sampling_info:
  batch.next_batch_sampling_info.update_regex_vocab_mask()
- torch.cuda.current_stream().synchronize()
+ self.current_stream.synchronize()
  batch.next_batch_sampling_info.sampling_info_done.set()

  self.stream_output(batch.reqs)
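In both the prefill and decode result paths, the grammar object is now told when its request has finished (`req.grammar.finished = req.finished()`). One plausible reading is that a finished grammar can skip further vocab-mask work; the wrapper below is a hypothetical illustration of that idea, not the real xgrammar/outlines backend code:

```python
import torch


class ToyGrammar:
    """Hypothetical grammar wrapper: once `finished` is set, mask work is skipped."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.finished = False
        self.accepted = []

    def accept_token(self, token_id: int):
        self.accepted.append(token_id)

    def fill_vocab_mask(self, mask: torch.Tensor):
        if self.finished:
            return  # request already finished: leave the mask untouched
        mask.zero_()
        # A real backend would mark only the grammar-legal tokens here.
        mask[: self.vocab_size] = 1


grammar = ToyGrammar(vocab_size=8)
grammar.accept_token(3)
grammar.finished = True          # mirrors `req.grammar.finished = req.finished()`
mask = torch.zeros(8)
grammar.fill_vocab_mask(mask)    # no-op because the request is done
```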
@@ -1146,6 +1122,14 @@
  + 1 : len(req.fill_ids)
  - req.last_update_decode_tokens
  ]
+
+ # Clip the padded hash values from image tokens.
+ # Otherwise, it will lead to detokenization errors.
+ input_token_ids = [
+ x if x < self.model_config.vocab_size - 1 else 0
+ for x in input_token_ids
+ ]
+
  req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

  if (
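The new clipping guards the logprob path for multimodal requests: padded image-token hashes are stored as ids at or above vocab_size - 1, so they are mapped to 0 before being zipped with the logprobs and handed to the detokenizer. A tiny illustration with made-up numbers:

```python
vocab_size = 32000

# The third and fifth entries are padded image-hash placeholders, not real token ids.
input_token_ids = [15, 902, 31999, 7, 2_147_483_647]

clipped = [x if x < vocab_size - 1 else 0 for x in input_token_ids]
print(clipped)  # [15, 902, 0, 7, 0] -- safe to feed to the detokenizer
```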
@@ -1293,6 +1277,61 @@
  )
  )

+ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+ # Check if other DP workers have running batches
+ if local_batch is None:
+ num_tokens = 0
+ elif local_batch.forward_mode.is_decode():
+ num_tokens = local_batch.batch_size()
+ else:
+ num_tokens = local_batch.extend_num_tokens
+
+ local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+ global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+ torch.distributed.all_gather_into_tensor(
+ global_num_tokens,
+ local_num_tokens,
+ group=self.tp_cpu_group,
+ )
+
+ if local_batch is None and global_num_tokens.max().item() > 0:
+ local_batch = self.get_idle_batch()
+
+ if local_batch is not None:
+ local_batch.global_num_tokens = global_num_tokens.tolist()
+
+ # Check forward mode for cuda graph
+ if not self.server_args.disable_cuda_graph:
+ forward_mode_state = torch.tensor(
+ (
+ 1
+ if local_batch.forward_mode.is_decode()
+ or local_batch.forward_mode.is_idle()
+ else 0
+ ),
+ dtype=torch.int32,
+ )
+ torch.distributed.all_reduce(
+ forward_mode_state,
+ op=torch.distributed.ReduceOp.MIN,
+ group=self.tp_cpu_group,
+ )
+ local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+ return local_batch
+
+ def get_idle_batch(self):
+ idle_batch = ScheduleBatch.init_new(
+ [],
+ self.req_to_token_pool,
+ self.token_to_kv_pool,
+ self.tree_cache,
+ self.model_config,
+ self.enable_overlap,
+ )
+ idle_batch.prepare_for_idle()
+ return idle_batch
+
  def move_ready_grammar_requests(self):
  """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
  num_ready_reqs = 0
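prepare_dp_attn_batch (moved here unchanged from earlier in the file) makes every data-parallel attention worker agree on two things before a forward pass: how many tokens each rank will run, so idle ranks build a dummy batch, and whether all ranks are in decode/idle mode, so the DP CUDA graph can be used. A self-contained sketch of the two collectives on a single-process gloo group; the group here is a stand-in for self.tp_cpu_group:

```python
import torch
import torch.distributed as dist

# Single-process "cluster" just to exercise the collectives; in sglang this is
# the CPU (gloo) group shared by the attention-DP workers (self.tp_cpu_group).
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29531", rank=0, world_size=1
)
group = dist.group.WORLD

# 1) Gather every rank's token count so an idle rank knows it must still join
#    the forward pass with a dummy ("idle") batch.
local_num_tokens = torch.tensor([0], dtype=torch.int64)  # this rank is idle
global_num_tokens = torch.empty(dist.get_world_size(), dtype=torch.int64)
dist.all_gather_into_tensor(global_num_tokens, local_num_tokens, group=group)
need_idle_batch = local_num_tokens.item() == 0 and global_num_tokens.max().item() > 0

# 2) Agree on the forward mode: MIN over {1 = decode/idle, 0 = extend} is 1
#    only if *every* rank can run the decode CUDA graph.
forward_mode_state = torch.tensor(1, dtype=torch.int32)
dist.all_reduce(forward_mode_state, op=dist.ReduceOp.MIN, group=group)
can_run_dp_cuda_graph = forward_mode_state.item() == 1

print(need_idle_batch, can_run_dp_cuda_graph)
dist.destroy_process_group()
```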
@@ -1361,9 +1400,9 @@
  req.to_abort = True
  break

- def update_weights(self, recv_req: UpdateWeightReqInput):
- """In-place update of the weights."""
- success, message = self.tp_worker.update_weights(recv_req)
+ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+ """In-place update of the weights from disk."""
+ success, message = self.tp_worker.update_weights_from_disk(recv_req)
  if success:
  flash_cache_success = self.flush_cache()
  assert flash_cache_success, "Cache flush failed after updating weights"
@@ -1371,6 +1410,27 @@
  logger.error(message)
  return success, message

+ def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+ """Initialize the online model parameter update group."""
+ success, message = self.tp_worker.init_weights_update_group(recv_req)
+ return success, message
+
+ def update_weights_from_distributed(
+ self, recv_req: UpdateWeightsFromDistributedReqInput
+ ):
+ """Update the online model parameter."""
+ success, message = self.tp_worker.update_weights_from_distributed(recv_req)
+ if success:
+ flash_cache_success = self.flush_cache()
+ assert flash_cache_success, "Cache flush failed after updating weights"
+ else:
+ logger.error(message)
+ return success, message
+
+ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+ parameter = self.tp_worker.get_weights_by_name(recv_req)
+ return parameter
+
  def start_profile(self) -> None:
  if self.profiler is None:
  raise RuntimeError("Profiler is not enabled.")
@@ -1413,10 +1473,6 @@ def run_scheduler_process(
  dp_rank: Optional[int],
  pipe_writer,
  ):
- # set cpu affinity to this gpu process
- if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
- set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
  # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
  if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
  dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1426,6 +1482,10 @@
  else:
  configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+ # set cpu affinity to this gpu process
+ if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+ set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
  suppress_other_loggers()
  parent_process = psutil.Process().parent()
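The CPU-affinity call itself is unchanged; it now simply runs after logging is configured, so any affinity problems are reported with the proper DP/TP prefix. set_gpu_proc_affinity lives in sglang.srt.utils; the sketch below is a rough, hypothetical illustration of what opting in via SGLANG_SET_CPU_AFFINITY does (the even-split core heuristic and function body are assumptions, not the exact implementation, and psutil's cpu_affinity is only available on Linux and Windows):

```python
import os

import psutil


def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).lower() in ("true", "1")


def set_gpu_proc_affinity_sketch(tp_size: int, nnodes: int, gpu_id: int):
    # Hypothetical heuristic: give each GPU worker an equal slice of the host cores.
    gpus_per_node = max(1, tp_size // max(1, nnodes))
    total_cores = psutil.cpu_count(logical=True)
    cores_per_gpu = max(1, total_cores // gpus_per_node)
    start = (gpu_id % gpus_per_node) * cores_per_gpu
    cores = list(range(start, min(start + cores_per_gpu, total_cores)))
    psutil.Process().cpu_affinity(cores)
    print(f"pinned GPU {gpu_id} worker to cores {cores}")


if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
    set_gpu_proc_affinity_sketch(tp_size=8, nnodes=1, gpu_id=0)
```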