sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
@@ -45,20 +45,24 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UpdateWeightReqInput,
-    UpdateWeightReqOutput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket, kill_child_process
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -105,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
 
         self.is_generation = self.model_config.is_generation
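
Note: the hunk above (and the matching one in TpModelWorker further down) threads three new ServerArgs fields into ModelConfig. A minimal construction sketch follows; the concrete values are illustrative assumptions, not taken from this diff:

```python
# Illustrative only: ModelConfig in 0.4.0 also accepts revision, dtype, and
# quantization. The values below are assumptions, not diff content.
from sglang.srt.configs.model_config import ModelConfig

model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # model_path: HF repo id or local dir
    trust_remote_code=False,
    revision=None,               # optionally pin a specific HF revision
    context_length=None,         # None -> use the model's own max context
    model_override_args="{}",    # JSON string, as in --json-model-override-args
    is_embedding=False,
    dtype="auto",
    quantization=None,
)
```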
@@ -218,7 +225,8 @@ class TokenizerManager:
         input_ids = obj.input_ids
 
         if self.is_generation:
-            image_inputs = await self.image_processor.process_images_async(
+            # TODO: also support getting embeddings for multimodal models
+            image_inputs: Dict = await self.image_processor.process_images_async(
                 obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
@@ -331,6 +339,12 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+            if batch_size > 128:
+                logger.warning(
+                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                    "The performance might be better if you just duplicate the requests n times or use "
+                    "many threads to send them one by one with parallel sampling (n > 1)."
+                )
 
             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
@@ -406,27 +420,10 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
 
-    async def get_memory_pool_size(self):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        req = GetMemPoolSizeReq()
-
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-
-        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
-        if self.server_args.dp_size == 1:
-            res = await self.mem_pool_size
-            return res.size
-        else:  # self.server_args.dp_size > 1
-            self.mem_pool_size_tmp = []
-            res = await self.mem_pool_size
-            ret = [r.size for r in res]
-            return ret
-
-    async def update_weights(
-        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    async def update_weights_from_disk(
+        self,
+        obj: UpdateWeightFromDiskReqInput,
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -471,6 +468,63 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."
 
+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> bool:
+        if self.to_create_loop:
+            self.create_handle_loop()
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.init_weights_update_group_result = asyncio.Future()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = await self.init_weights_update_group_result
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                self.send_to_scheduler.send_pyobj(obj)
+                self.parameter_update_result = asyncio.Future()
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be for update weights from distributed"
+                result = await self.parameter_update_result
+                return result.success, result.message
+        else:
+            logger.error("Another parameter update is in progress in tokenizer manager")
+            return (
+                False,
+                "Another parameter update is in progress. Please try again later.",
+            )
+
+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
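
Together with update_weights_from_disk above, TokenizerManager now exposes four weight-related RPCs. Below is a hedged usage sketch (not part of the diff) of how a caller holding a TokenizerManager instance might drive them; the keyword fields mirror the attributes read elsewhere in this diff, and everything else (concrete values, dtype/shape encodings) is assumed:

```python
# Hypothetical driver for the new weight-update RPCs. Assumes a running
# TokenizerManager instance `tm` (dp_size == 1) and that the request
# dataclasses accept these keyword fields.
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)


async def refresh_weights(tm):
    # 1) Join the process group a trainer uses to broadcast updated weights.
    success, message = await tm.init_weights_update_group(
        InitWeightsUpdateGroupReqInput(
            master_address="127.0.0.1",
            master_port=29500,
            rank_offset=1,
            world_size=2,
            group_name="weight_sync",
            backend="nccl",
        )
    )
    assert success, message

    # 2) Receive one tensor broadcast over that group.
    success, message = await tm.update_weights_from_distributed(
        UpdateWeightsFromDistributedReqInput(
            name="model.layers.0.self_attn.q_proj.weight",
            dtype="bfloat16",
            shape=[4096, 4096],
        )
    )
    assert success, message

    # 3) Read back a truncated copy of a parameter for verification.
    return await tm.get_weights_by_name(
        GetWeightsByNameReqInput(name="lm_head.weight", truncate_size=100)
    )
```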
@@ -532,7 +586,7 @@ class TokenizerManager:
             else:
                 break
 
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(0)
 
     async def handle_loop(self):
@@ -540,10 +594,77 @@ class TokenizerManager:
 
         while True:
             recv_obj: Union[
-                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+                BatchStrOut,
+                BatchEmbeddingOut,
+                BatchTokenIDOut,
+                UpdateWeightFromDiskReqOutput,
+                UpdateWeightsFromDistributedReqOutput,
+                GetWeightsByNameReqOutput,
+                InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()
 
-            if isinstance(recv_obj, UpdateWeightReqOutput):
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else:  # self.server_args.dp_size > 1
@@ -551,79 +672,27 @@ class TokenizerManager:
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-                continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.parameter_update_result.set_result(recv_obj)
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
                 if self.server_args.dp_size == 1:
-                    self.mem_pool_size.set_result(recv_obj)
-                else:  # self.sever_args.dp_size > 1
-                    self.mem_pool_size_tmp.append(recv_obj)
-                    # set future if the all results are received
-                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
-                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
-                continue
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
-                )
-                continue
-
-            assert isinstance(
-                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
-            ), f"Unexpected obj received: {type(recv_obj)}"
-
-            for i, rid in enumerate(recv_obj.rids):
-                state = self.rid_to_state.get(rid, None)
-                if state is None:
-                    continue
-
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
+                    self.get_weights_by_name_result.set_result(recv_obj)
                 else:
-                    assert isinstance(recv_obj, BatchEmbeddingOut)
-                    out_dict = {
-                        "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
                         )
-                    else:
-                        if completion_tokens >= 2:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.first_token_time)
-                                / (completion_tokens - 1)
-                            )
-
-                    if state.finished:
-                        self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
-                        )
-                        self.metrics_collector.inc_generation_tokens(completion_tokens)
-                        self.metrics_collector.observe_e2e_request_latency(
-                            time.time() - state.created_time
-                        )
-                        if completion_tokens >= 1:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.created_time) / completion_tokens
-                            )
+            else:
+                raise ValueError(f"Invalid object: {recv_obj=}")
 
     def convert_logprob_style(
         self,
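
The dp_size > 1 handling above follows one pattern in every branch: each data-parallel replica sends its own result object, the manager collects them in a temporary list, and the shared Future is resolved only once dp_size results have arrived. A standalone, simplified sketch of that idea (not sglang code):

```python
# Simplified illustration of the "collect dp_size results, then resolve one
# future" pattern used by handle_loop.
import asyncio


class ResultCollector:
    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        self.future: asyncio.Future = asyncio.Future()
        self._tmp = []

    def on_result(self, result) -> None:
        if self.dp_size == 1:
            self.future.set_result(result)
        else:
            self._tmp.append(result)
            if len(self._tmp) == self.dp_size:
                self.future.set_result(self._tmp)


async def main():
    collector = ResultCollector(dp_size=2)
    collector.on_result("replica-0")
    collector.on_result("replica-1")
    print(await collector.future)  # ['replica-0', 'replica-1']


asyncio.run(main())
```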
@@ -19,7 +19,12 @@ from typing import Optional
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -47,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -155,8 +163,33 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings
 
-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.model_runner.update_weights(
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.name, recv_req.dtype, recv_req.shape
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
@@ -15,16 +15,24 @@
 
 import dataclasses
 import logging
+import signal
 import threading
 from queue import Queue
 from typing import Optional
 
+import psutil
 import torch
 
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +78,7 @@ class TpModelWorkerClient:
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
+        self.parent_process = psutil.Process().parent()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -87,8 +96,13 @@ class TpModelWorkerClient:
         )
 
     def forward_thread_func(self):
-        with torch.cuda.stream(self.forward_stream):
-            self.forward_thread_func_()
+        try:
+            with torch.cuda.stream(self.forward_stream):
+                self.forward_thread_func_()
+        except Exception:
+            traceback = get_exception_traceback()
+            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
+            self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
     def forward_thread_func_(self):
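
The try/except added above makes the background forward thread fail loudly: the traceback is logged and the parent process is asked to shut down via SIGQUIT instead of leaving the server running with a dead worker thread. A standalone sketch of the same pattern (illustration only, not sglang code):

```python
# Run work in a daemon thread; on an unexpected exception, log it and signal
# the parent process so the whole server exits rather than limping along.
import logging
import signal
import threading
import traceback

import psutil

logger = logging.getLogger(__name__)


def run_in_thread(fn) -> threading.Thread:
    parent = psutil.Process().parent()

    def _wrapped():
        try:
            fn()
        except Exception:
            logger.error("worker thread crashed: %s", traceback.format_exc())
            if parent is not None:
                # Escalate: let the parent's SIGQUIT handler tear everything down.
                parent.send_signal(signal.SIGQUIT)

    thread = threading.Thread(target=_wrapped, daemon=True)
    thread.start()
    return thread


# Example (hypothetical): run_in_thread(worker.forward_thread_func_)
```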
@@ -195,10 +209,23 @@ class TpModelWorkerClient:
         ) % self.future_token_ids_limit
         return None, future_next_token_ids
 
-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.worker.update_weights(recv_req)
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message
 
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.worker.update_weights_from_distributed(recv_req)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)
 
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
 
@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
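
patch_model now receives the capture batch size and compiles the forward with dynamic=False, so torch.compile specializes one graph per batch size in compile_bs rather than building a single dynamic-shape graph. A toy, standalone illustration of that behavior (not sglang code):

```python
# With dynamic=False, torch.compile specializes on the shapes it sees, so
# compiling once per batch size mirrors what CudaGraphRunner does per bs.
import torch


def step(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) * 2


compiled_per_bs = {
    bs: torch.compile(step, mode="max-autotune-no-cudagraphs", dynamic=False)
    for bs in (1, 2, 4)
}

for bs, fn in compiled_per_bs.items():
    print(bs, fn(torch.randn(bs, 8)).shape)
```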
@@ -256,10 +256,15 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_num_tokens = batch.extend_num_tokens
-            ret.positions, ret.extend_start_loc = compute_position_triton(
-                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
-            )
+            if model_runner.server_args.attention_backend != "torch_native":
+                ret.extend_num_tokens = batch.extend_num_tokens
+                ret.positions, ret.extend_start_loc = compute_position_triton(
+                    ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                )
+            else:
+                ret.positions, ret.extend_start_loc = compute_position_torch(
+                    ret.extend_prefix_lens, ret.extend_seq_lens
+                )
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
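
For the torch_native attention backend, extend positions are computed without the Triton kernel. The actual compute_position_torch implementation is not shown in this diff; the sketch below only illustrates what such a pure-PyTorch fallback has to produce: per-token positions starting at each request's prefix length, plus each request's start offset in the flattened extend batch.

```python
# Assumed semantics of a torch-native position computation; not the sglang
# implementation itself.
import torch


def compute_position_torch_sketch(extend_prefix_lens, extend_seq_lens):
    # positions: prefix_len, prefix_len + 1, ..., prefix_len + seq_len - 1
    # for every request, concatenated in batch order.
    positions = torch.cat(
        [
            torch.arange(p, p + s, dtype=torch.int64)
            for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    # extend_start_loc: offset of each request's first token in the flat batch.
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens, dim=0)[:-1]
    return positions, extend_start_loc


positions, start_loc = compute_position_torch_sketch(
    torch.tensor([0, 3]), torch.tensor([4, 2])
)
print(positions.tolist(), start_loc.tolist())  # [0, 1, 2, 3, 3, 4] [0, 4]
```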