sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tokenizer_manager.py
@@ -45,13 +45,19 @@ from sglang.srt.managers.io_struct import (
  EmbeddingReqInput,
  FlushCacheReq,
  GenerateReqInput,
+ GetWeightsByNameReqInput,
+ GetWeightsByNameReqOutput,
+ InitWeightsUpdateGroupReqInput,
+ InitWeightsUpdateGroupReqOutput,
  OpenSessionReqInput,
  OpenSessionReqOutput,
  ProfileReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
- UpdateWeightReqInput,
- UpdateWeightReqOutput,
+ UpdateWeightFromDiskReqInput,
+ UpdateWeightFromDiskReqOutput,
+ UpdateWeightsFromDistributedReqInput,
+ UpdateWeightsFromDistributedReqOutput,
  )
  from sglang.srt.metrics.collector import TokenizerMetricsCollector
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -103,9 +109,12 @@ class TokenizerManager:
  self.model_config = ModelConfig(
  server_args.model_path,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ dtype=server_args.dtype,
+ quantization=server_args.quantization,
  )

  self.is_generation = self.model_config.is_generation
@@ -330,6 +339,12 @@ class TokenizerManager:
  rids.append(tmp_obj.rid)
  else:
  # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+ if batch_size > 128:
+ logger.warning(
+ "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+ "The performance might be better if you just duplicate the requests n times or use "
+ "many threads to send them one by one with parallel sampling (n > 1)."
+ )

  # Tokenize all requests
  objs = [obj[i] for i in range(batch_size)]
@@ -405,8 +420,10 @@ class TokenizerManager:
  req = ProfileReq.STOP_PROFILE
  self.send_to_scheduler.send_pyobj(req)

- async def update_weights(
- self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+ async def update_weights_from_disk(
+ self,
+ obj: UpdateWeightFromDiskReqInput,
+ request: Optional[fastapi.Request] = None,
  ):
  if self.to_create_loop:
  self.create_handle_loop()
@@ -451,6 +468,63 @@ class TokenizerManager:
  else:
  return False, "Another update is in progress. Please try again later."

+ async def init_weights_update_group(
+ self,
+ obj: InitWeightsUpdateGroupReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> bool:
+ if self.to_create_loop:
+ self.create_handle_loop()
+ self.send_to_scheduler.send_pyobj(obj)
+
+ self.init_weights_update_group_result = asyncio.Future()
+ assert (
+ self.server_args.dp_size == 1
+ ), "dp_size must be 1 for init parameter update group"
+ result = await self.init_weights_update_group_result
+ return result.success, result.message
+
+ async def update_weights_from_distributed(
+ self,
+ obj: UpdateWeightsFromDistributedReqInput,
+ request: Optional[fastapi.Request] = None,
+ ):
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ if not self.model_update_lock.locked():
+ async with self.model_update_lock:
+ self.send_to_scheduler.send_pyobj(obj)
+ self.parameter_update_result = asyncio.Future()
+ assert (
+ self.server_args.dp_size == 1
+ ), "dp_size must be for update weights from distributed"
+ result = await self.parameter_update_result
+ return result.success, result.message
+ else:
+ logger.error("Another parameter update is in progress in tokenizer manager")
+ return (
+ False,
+ "Another parameter update is in progress. Please try again later.",
+ )
+
+ async def get_weights_by_name(
+ self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+ ):
+ if self.to_create_loop:
+ self.create_handle_loop()
+
+ self.send_to_scheduler.send_pyobj(obj)
+ self.get_weights_by_name_result = asyncio.Future()
+ if self.server_args.dp_size == 1:
+ result = await self.get_weights_by_name_result
+ return result.parameter
+ else:
+ self.get_weights_by_name_tmp = []
+ result = await self.get_weights_by_name_result
+ all_parameters = [r.parameter for r in result]
+ return all_parameters
+
  async def open_session(
  self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
  ):
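
The four coroutines added above form the tokenizer-side API for online weight synchronization. A rough sketch of driving them directly follows; the request-object field names are inferred from how TpModelWorker unpacks them later in this diff, while the keyword-constructor form and the concrete address, port, and group values are assumptions for illustration, not code from the package.

```python
# Sketch only: field names inferred from TpModelWorker below; values are hypothetical.
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)

async def sync_one_param(tokenizer_manager, name, dtype, shape):
    # 1) Create the weight-update process group once (dp_size must be 1).
    ok, msg = await tokenizer_manager.init_weights_update_group(
        InitWeightsUpdateGroupReqInput(
            master_address="127.0.0.1",   # hypothetical trainer rendezvous
            master_port=23456,
            rank_offset=1,
            world_size=2,
            group_name="weight_sync",
            backend="nccl",
        )
    )
    assert ok, msg

    # 2) Apply one tensor broadcast by the trainer into the running model.
    ok, msg = await tokenizer_manager.update_weights_from_distributed(
        UpdateWeightsFromDistributedReqInput(name=name, dtype=dtype, shape=shape)
    )
    assert ok, msg

    # 3) Read the (truncated) parameter back to verify.
    return await tokenizer_manager.get_weights_by_name(
        GetWeightsByNameReqInput(name=name, truncate_size=100)
    )
```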
@@ -520,10 +594,77 @@ class TokenizerManager:

  while True:
  recv_obj: Union[
- BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+ BatchStrOut,
+ BatchEmbeddingOut,
+ BatchTokenIDOut,
+ UpdateWeightFromDiskReqOutput,
+ UpdateWeightsFromDistributedReqOutput,
+ GetWeightsByNameReqOutput,
+ InitWeightsUpdateGroupReqOutput,
  ] = await self.recv_from_detokenizer.recv_pyobj()

- if isinstance(recv_obj, UpdateWeightReqOutput):
+ if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+ for i, rid in enumerate(recv_obj.rids):
+ state = self.rid_to_state.get(rid, None)
+ if state is None:
+ continue
+
+ recv_obj.meta_info[i]["id"] = rid
+ if isinstance(recv_obj, BatchStrOut):
+ out_dict = {
+ "text": recv_obj.output_strs[i],
+ "meta_info": recv_obj.meta_info[i],
+ }
+ elif isinstance(recv_obj, BatchTokenIDOut):
+ out_dict = {
+ "token_ids": recv_obj.output_ids[i],
+ "meta_info": recv_obj.meta_info[i],
+ }
+ else:
+ assert isinstance(recv_obj, BatchEmbeddingOut)
+ out_dict = {
+ "embedding": recv_obj.embeddings[i],
+ "meta_info": recv_obj.meta_info[i],
+ }
+ state.out_list.append(out_dict)
+ state.finished = recv_obj.finished_reason[i] is not None
+ state.event.set()
+
+ if self.enable_metrics:
+ completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+ if state.first_token_time is None:
+ state.first_token_time = time.time()
+ self.metrics_collector.observe_time_to_first_token(
+ state.first_token_time - state.created_time
+ )
+ else:
+ if completion_tokens >= 2:
+ self.metrics_collector.observe_time_per_output_token(
+ (time.time() - state.first_token_time)
+ / (completion_tokens - 1)
+ )
+
+ if state.finished:
+ self.metrics_collector.inc_prompt_tokens(
+ recv_obj.meta_info[i]["prompt_tokens"]
+ )
+ self.metrics_collector.inc_generation_tokens(
+ completion_tokens
+ )
+ self.metrics_collector.observe_e2e_request_latency(
+ time.time() - state.created_time
+ )
+ if completion_tokens >= 1:
+ self.metrics_collector.observe_time_per_output_token(
+ (time.time() - state.created_time)
+ / completion_tokens
+ )
+ elif isinstance(recv_obj, OpenSessionReqOutput):
+ self.session_futures[recv_obj.session_id].set_result(
+ recv_obj.session_id
+ )
+ elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
  if self.server_args.dp_size == 1:
  self.model_update_result.set_result(recv_obj)
  else: # self.server_args.dp_size > 1
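
The metrics branch above records time-to-first-token once and then averages inter-token latency over the remaining tokens. A small worked example of the arithmetic, with illustrative timestamps:

```python
# TTFT is taken at the first token; afterwards the inter-token latency is the
# elapsed time since the first token divided by (completion_tokens - 1).
created_time = 100.0
first_token_time = 100.8            # -> TTFT = 0.8 s
now = 103.2
completion_tokens = 13

ttft = first_token_time - created_time                      # 0.8 s
tpot = (now - first_token_time) / (completion_tokens - 1)   # 2.4 / 12 = 0.2 s/token
e2e = now - created_time                                    # 3.2 s at finish
print(ttft, tpot, e2e)
```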
@@ -531,70 +672,27 @@ class TokenizerManager:
  # set future if the all results are recevied
  if len(self.model_update_tmp) == self.server_args.dp_size:
  self.model_update_result.set_result(self.model_update_tmp)
- continue
- elif isinstance(recv_obj, OpenSessionReqOutput):
- self.session_futures[recv_obj.session_id].set_result(
- recv_obj.session_id
- )
- continue
-
- assert isinstance(
- recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
- ), f"Unexpected obj received: {type(recv_obj)}"
-
- for i, rid in enumerate(recv_obj.rids):
- state = self.rid_to_state.get(rid, None)
- if state is None:
- continue
-
- recv_obj.meta_info[i]["id"] = rid
- if isinstance(recv_obj, BatchStrOut):
- out_dict = {
- "text": recv_obj.output_strs[i],
- "meta_info": recv_obj.meta_info[i],
- }
- elif isinstance(recv_obj, BatchTokenIDOut):
- out_dict = {
- "token_ids": recv_obj.output_ids[i],
- "meta_info": recv_obj.meta_info[i],
- }
+ elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+ assert (
+ self.server_args.dp_size == 1
+ ), "dp_size must be 1 for init parameter update group"
+ self.init_weights_update_group_result.set_result(recv_obj)
+ elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+ assert (
+ self.server_args.dp_size == 1
+ ), "dp_size must be 1 for update weights from distributed"
+ self.parameter_update_result.set_result(recv_obj)
+ elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+ if self.server_args.dp_size == 1:
+ self.get_weights_by_name_result.set_result(recv_obj)
  else:
- assert isinstance(recv_obj, BatchEmbeddingOut)
- out_dict = {
- "embedding": recv_obj.embeddings[i],
- "meta_info": recv_obj.meta_info[i],
- }
- state.out_list.append(out_dict)
- state.finished = recv_obj.finished_reason[i] is not None
- state.event.set()
-
- if self.enable_metrics:
- completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
- if state.first_token_time is None:
- state.first_token_time = time.time()
- self.metrics_collector.observe_time_to_first_token(
- state.first_token_time - state.created_time
+ self.get_weights_by_name_tmp.append(recv_obj)
+ if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+ self.get_weights_by_name_result.set_result(
+ self.get_weights_by_name_tmp
  )
- else:
- if completion_tokens >= 2:
- self.metrics_collector.observe_time_per_output_token(
- (time.time() - state.first_token_time)
- / (completion_tokens - 1)
- )
-
- if state.finished:
- self.metrics_collector.inc_prompt_tokens(
- recv_obj.meta_info[i]["prompt_tokens"]
- )
- self.metrics_collector.inc_generation_tokens(completion_tokens)
- self.metrics_collector.observe_e2e_request_latency(
- time.time() - state.created_time
- )
- if completion_tokens >= 1:
- self.metrics_collector.observe_time_per_output_token(
- (time.time() - state.created_time) / completion_tokens
- )
+ else:
+ raise ValueError(f"Invalid object: {recv_obj=}")

  def convert_logprob_style(
  self,

sglang/srt/managers/tp_worker.py
@@ -19,7 +19,12 @@ from typing import Optional

  from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.managers.io_struct import UpdateWeightReqInput
+ from sglang.srt.managers.io_struct import (
+ GetWeightsByNameReqInput,
+ InitWeightsUpdateGroupReqInput,
+ UpdateWeightFromDiskReqInput,
+ UpdateWeightsFromDistributedReqInput,
+ )
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -47,9 +52,12 @@ class TpModelWorker:
  self.model_config = ModelConfig(
  server_args.model_path,
  trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ dtype=server_args.dtype,
+ quantization=server_args.quantization,
  )
  self.model_runner = ModelRunner(
  model_config=self.model_config,
@@ -155,8 +163,33 @@ class TpModelWorker:
  embeddings = logits_output.embeddings
  return embeddings

- def update_weights(self, recv_req: UpdateWeightReqInput):
- success, message = self.model_runner.update_weights(
+ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+ success, message = self.model_runner.update_weights_from_disk(
  recv_req.model_path, recv_req.load_format
  )
  return success, message
+
+ def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+ success, message = self.model_runner.init_weights_update_group(
+ recv_req.master_address,
+ recv_req.master_port,
+ recv_req.rank_offset,
+ recv_req.world_size,
+ recv_req.group_name,
+ recv_req.backend,
+ )
+ return success, message
+
+ def update_weights_from_distributed(
+ self, recv_req: UpdateWeightsFromDistributedReqInput
+ ):
+ success, message = self.model_runner.update_weights_from_distributed(
+ recv_req.name, recv_req.dtype, recv_req.shape
+ )
+ return success, message
+
+ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+ parameter = self.model_runner.get_weights_by_name(
+ recv_req.name, recv_req.truncate_size
+ )
+ return parameter
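
These worker methods only forward to ModelRunner; the trainer side of the new group is not part of this diff. A hypothetical counterpart, assuming the trainer acts as rank 0 of the group created via init_weights_update_group and pushes each named tensor with a plain torch.distributed broadcast (a sketch under assumptions, not code from the package):

```python
import torch
import torch.distributed as dist

def broadcast_updated_param(tensor: torch.Tensor, group=None) -> None:
    # Rank 0 (the trainer) is the source; the sglang TP ranks receive the data
    # and the server applies it via update_weights_from_distributed(name, dtype, shape).
    dist.broadcast(tensor, src=0, group=group)
```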

sglang/srt/managers/tp_worker_overlap_thread.py
@@ -23,16 +23,22 @@ from typing import Optional
  import psutil
  import torch

- from sglang.srt.managers.io_struct import UpdateWeightReqInput
+ from sglang.srt.managers.io_struct import (
+ GetWeightsByNameReqInput,
+ InitWeightsUpdateGroupReqInput,
+ UpdateWeightFromDiskReqInput,
+ UpdateWeightsFromDistributedReqInput,
+ )
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch
  from sglang.srt.managers.tp_worker import TpModelWorker
  from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import get_compiler_backend
  from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)


- @torch.compile(dynamic=True)
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
  def resolve_future_token_ids(input_ids, future_token_ids_map):
  input_ids[:] = torch.where(
  input_ids < 0,
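
get_compiler_backend() comes from sglang/srt/utils.py, which this diff lists (+295 -28) but does not show. A minimal sketch of the idea, explicitly an assumption rather than the packaged implementation: pick a torch.compile backend that matches the active accelerator instead of always using the CUDA/CPU default.

```python
import torch

def get_compiler_backend() -> str:
    # Assumed behavior: non-CUDA accelerator builds register their own backend.
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        return "hpu_backend"   # Gaudi/HPU torch.compile backend
    return "inductor"          # default torch.compile backend on CUDA/CPU
```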
@@ -68,12 +74,13 @@ class TpModelWorkerClient:
  # Launch threads
  self.input_queue = Queue()
  self.output_queue = Queue()
- self.forward_stream = torch.cuda.Stream()
+ self.forward_stream = torch.get_device_module(self.device).Stream()
  self.forward_thread = threading.Thread(
  target=self.forward_thread_func,
  )
  self.forward_thread.start()
  self.parent_process = psutil.Process().parent()
+ self.scheduler_stream = torch.get_device_module(self.device).current_stream()

  def get_worker_info(self):
  return self.worker.get_worker_info()
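
torch.get_device_module resolves a device string to its backend namespace, so the overlap worker no longer hard-codes torch.cuda for streams and events. A quick illustration:

```python
import torch

if torch.cuda.is_available():
    mod = torch.get_device_module("cuda")   # -> torch.cuda
    assert mod is torch.cuda
    stream = mod.Stream()                   # equivalent to torch.cuda.Stream()
    current = mod.current_stream()          # equivalent to torch.cuda.current_stream()
```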
@@ -92,7 +99,7 @@ class TpModelWorkerClient:

  def forward_thread_func(self):
  try:
- with torch.cuda.stream(self.forward_stream):
+ with torch.get_device_module(self.device).stream(self.forward_stream):
  self.forward_thread_func_()
  except Exception:
  traceback = get_exception_traceback()
@@ -117,7 +124,7 @@ class TpModelWorkerClient:

  # Create event
  self.launch_done = threading.Event()
- copy_done = torch.cuda.Event()
+ copy_done = torch.get_device_module(self.device).Event()

  # Resolve future tokens in the input
  input_ids = model_worker_batch.input_ids
@@ -185,7 +192,7 @@ class TpModelWorkerClient:
  )

  # A cuda stream sync here to avoid the cuda illegal memory access error.
- torch.cuda.current_stream().synchronize()
+ self.scheduler_stream.synchronize()

  # Push a new batch to the queue
  self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
@@ -204,10 +211,23 @@ class TpModelWorkerClient:
  ) % self.future_token_ids_limit
  return None, future_next_token_ids

- def update_weights(self, recv_req: UpdateWeightReqInput):
- success, message = self.worker.update_weights(recv_req)
+ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+ success, message = self.worker.update_weights_from_disk(recv_req)
  return success, message

+ def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+ success, message = self.worker.init_weights_update_group(recv_req)
+ return success, message
+
+ def update_weights_from_distributed(
+ self, recv_req: UpdateWeightsFromDistributedReqInput
+ ):
+ success, message = self.worker.update_weights_from_distributed(recv_req)
+ return success, message
+
+ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+ return self.worker.get_weights_by_name(recv_req)
+
  def __delete__(self):
  self.input_queue.put((None, None))
  self.copy_queue.put((None, None, None))

sglang/srt/mem_cache/memory_pool.py
@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
  import torch

  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.utils import get_compiler_backend

  logger = logging.getLogger(__name__)

@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
  return select_index.to(self.device, non_blocking=True)

  def free(self, free_index: torch.Tensor):
+ if free_index.numel() == 0:
+ return
+
  if self.is_not_in_free_group:
  self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
  else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):

  # This compiled version is slower in the unit test
  # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
- @torch.compile(dynamic=True)
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
  def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
  dst_1[loc] = src_1.to(dtype).view(store_dtype)
  dst_2[loc] = src_2.to(dtype).view(store_dtype)

sglang/srt/model_executor/cuda_graph_runner.py
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner


- def _to_torch(model: torch.nn.Module, reverse: bool = False):
+ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
  for sub in model._modules.values():
  if isinstance(sub, CustomOp):
  if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
  else:
  # NOTE: Temporarily workaround MoE
  if "FusedMoE" in sub.__class__.__name__:
- sub._forward_method = fused_moe_forward_native
+ if batch_size == 1:
+ # The performance of torch.compile on this layer is not always good when bs > 1,
+ # so we decide to only use torch.compile when bs =1
+ sub._forward_method = fused_moe_forward_native
  else:
  sub._forward_method = sub.forward_native
  setattr(sub, "is_torch_compile", True)
  if isinstance(sub, torch.nn.Module):
- _to_torch(sub, reverse)
+ _to_torch(sub, reverse, batch_size)


  @contextmanager
  def patch_model(
- model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+ model: torch.nn.Module,
+ enable_compile: bool,
+ batch_size: int,
+ tp_group: "GroupCoordinator",
  ):
  """Patch the model to make it compatible with with torch.compile"""
  backup_ca_comm = None

  try:
  if enable_compile:
- _to_torch(model)
+ _to_torch(model, reverse=False, batch_size=batch_size)
  monkey_patch_vllm_all_gather()
  backup_ca_comm = tp_group.ca_comm
  # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
  # even with ENABLE_INTRA_NODE_COMM=1.
  # tp_group.ca_comm = None
  yield torch.compile(
- torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+ torch.no_grad()(model.forward),
+ mode="max-autotune-no-cudagraphs",
+ dynamic=False,
  )
  else:
  yield model.forward
  finally:
  if enable_compile:
- _to_torch(model, reverse=True)
+ _to_torch(model, reverse=True, batch_size=batch_size)
  monkey_patch_vllm_all_gather(reverse=True)
  tp_group.ca_comm = backup_ca_comm

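With dynamic=False, torch.compile specializes on the exact input shapes it sees, so each captured batch size gets its own compiled graph instead of a single shape-dynamic one. An illustrative, CPU-runnable sketch:

```python
import torch

@torch.compile(dynamic=False)
def f(x):
    return x * 2 + 1

f(torch.randn(1, 8))   # compiles a graph specialized for batch size 1
f(torch.randn(4, 8))   # triggers a recompile for batch size 4
```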
@@ -122,6 +130,20 @@ class CudaGraphRunner:
  self.capture_bs = list(range(1, 32)) + [64, 128]
  else:
  self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+ if max(self.capture_bs) > model_runner.req_to_token_pool.size:
+ # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+ # is very samll. We add more values here to make sure we capture the maximum bs.
+ self.capture_bs = list(
+ sorted(
+ set(
+ self.capture_bs
+ + [model_runner.req_to_token_pool.size - 1]
+ + [model_runner.req_to_token_pool.size]
+ )
+ )
+ )
+
  self.capture_bs = [
  bs
  for bs in self.capture_bs
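
A quick worked example of the adjustment above, using a hypothetical pool of 48 request slots (e.g., a small --max-running-requests): the two largest runnable sizes are appended so the maximum batch size is still captured before the filtering list comprehension that follows prunes the list.

```python
capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]   # 1, 2, 4, 8, ..., 160
pool_size = 48                                           # hypothetical req_to_token_pool.size
if max(capture_bs) > pool_size:
    capture_bs = sorted(set(capture_bs + [pool_size - 1] + [pool_size]))
print(capture_bs)   # ..., 40, 47, 48, 56, ..., 160 (oversized entries are filtered afterwards)
```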
@@ -237,6 +259,7 @@ class CudaGraphRunner:
  with patch_model(
  self.model_runner.model,
  bs in self.compile_bs,
+ bs,
  self.model_runner.tp_group,
  ) as forward:
  (

sglang/srt/model_executor/forward_batch_info.py
@@ -256,10 +256,15 @@ class ForwardBatch:
  ret.extend_prefix_lens = torch.tensor(
  batch.extend_prefix_lens, dtype=torch.int32
  ).to(device, non_blocking=True)
- ret.extend_num_tokens = batch.extend_num_tokens
- ret.positions, ret.extend_start_loc = compute_position_triton(
- ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
- )
+ if model_runner.server_args.attention_backend != "torch_native":
+ ret.extend_num_tokens = batch.extend_num_tokens
+ ret.positions, ret.extend_start_loc = compute_position_triton(
+ ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+ )
+ else:
+ ret.positions, ret.extend_start_loc = compute_position_torch(
+ ret.extend_prefix_lens, ret.extend_seq_lens
+ )
  ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
  ret.extend_seq_lens_cpu = batch.extend_seq_lens
  ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
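
compute_position_torch is referenced above but its body is not shown in this diff. A minimal sketch of what such a pure-torch fallback could compute, inferred from how compute_position_triton is called; this is an assumption, not the packaged code.

```python
import torch

def compute_position_torch(extend_prefix_lens: torch.Tensor,
                           extend_seq_lens: torch.Tensor):
    # Positions of the newly extended tokens: prefix_len .. prefix_len + seq_len - 1
    # for each request, flattened into one vector.
    positions = torch.cat([
        torch.arange(int(p), int(p) + int(l), device=extend_seq_lens.device)
        for p, l in zip(extend_prefix_lens, extend_seq_lens)
    ])
    # Start offset of each request's tokens inside the flattened extend batch.
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions.to(torch.int64), extend_start_loc
```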