sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tokenizer_manager.py
@@ -45,13 +45,19 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UpdateWeightReqInput,
-    UpdateWeightReqOutput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -103,9 +109,12 @@ class TokenizerManager:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )

         self.is_generation = self.model_config.is_generation
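
The constructor calls changed here and in TpModelWorker below simply forward three new server options into ModelConfig. A minimal standalone sketch of the expanded call; only the keyword names come from this diff, the argument values are placeholders:

from sglang.srt.configs.model_config import ModelConfig

# Placeholder values; in the server these come from ServerArgs
# (revision, dtype, quantization).
model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",
    trust_remote_code=False,
    revision="main",
    context_length=4096,
    model_override_args="{}",
    is_embedding=False,
    dtype="bfloat16",
    quantization=None,
)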
@@ -330,6 +339,12 @@ class TokenizerManager:
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
+            if batch_size > 128:
+                logger.warning(
+                    "Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
+                    "The performance might be better if you just duplicate the requests n times or use "
+                    "many threads to send them one by one with parallel sampling (n > 1)."
+                )

             # Tokenize all requests
             objs = [obj[i] for i in range(batch_size)]
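
The warning added above suggests a client-side workaround: duplicate the prompt n times and send the copies concurrently instead of one huge batch with parallel sampling. A hedged sketch, assuming a local sglang server on port 30000 and its /generate endpoint:

import concurrent.futures

import requests

def generate_one(prompt: str) -> dict:
    resp = requests.post(
        "http://localhost:30000/generate",
        json={"text": prompt, "sampling_params": {"max_new_tokens": 64}},
    )
    resp.raise_for_status()
    return resp.json()

def generate_n(prompt: str, n: int) -> list:
    # Duplicate the request n times rather than sending one batch with n > 1.
    with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
        return list(pool.map(generate_one, [prompt] * n))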
@@ -405,8 +420,10 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

-    async def update_weights(
-        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    async def update_weights_from_disk(
+        self,
+        obj: UpdateWeightFromDiskReqInput,
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -451,6 +468,63 @@ class TokenizerManager:
         else:
             return False, "Another update is in progress. Please try again later."

+    async def init_weights_update_group(
+        self,
+        obj: InitWeightsUpdateGroupReqInput,
+        request: Optional[fastapi.Request] = None,
+    ) -> bool:
+        if self.to_create_loop:
+            self.create_handle_loop()
+        self.send_to_scheduler.send_pyobj(obj)
+
+        self.init_weights_update_group_result = asyncio.Future()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for init parameter update group"
+        result = await self.init_weights_update_group_result
+        return result.success, result.message
+
+    async def update_weights_from_distributed(
+        self,
+        obj: UpdateWeightsFromDistributedReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        if not self.model_update_lock.locked():
+            async with self.model_update_lock:
+                self.send_to_scheduler.send_pyobj(obj)
+                self.parameter_update_result = asyncio.Future()
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be for update weights from distributed"
+                result = await self.parameter_update_result
+                return result.success, result.message
+        else:
+            logger.error("Another parameter update is in progress in tokenizer manager")
+            return (
+                False,
+                "Another parameter update is in progress. Please try again later.",
+            )
+
+    async def get_weights_by_name(
+        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
+    ):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        self.send_to_scheduler.send_pyobj(obj)
+        self.get_weights_by_name_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.get_weights_by_name_result
+            return result.parameter
+        else:
+            self.get_weights_by_name_tmp = []
+            result = await self.get_weights_by_name_result
+            all_parameters = [r.parameter for r in result]
+            return all_parameters
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
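
Taken together, the three coroutines above let an external trainer push weights into a running server without touching disk. A hedged driver sketch, assuming HTTP endpoints whose paths mirror the method names (the paths are an assumption; only the field names appear in this diff, in the TpModelWorker hunk further down):

import requests

BASE = "http://localhost:30000"  # assumed local sglang server

def init_group():
    # Field names match InitWeightsUpdateGroupReqInput as used by TpModelWorker below.
    return requests.post(
        f"{BASE}/init_weights_update_group",
        json={
            "master_address": "127.0.0.1",
            "master_port": 29500,
            "rank_offset": 1,      # inference ranks sit after the trainer rank
            "world_size": 2,
            "group_name": "weight_sync",
            "backend": "nccl",
        },
    ).json()

def push_weight(name: str, dtype: str, shape: list):
    # The tensor itself travels over the process group; the request carries only metadata.
    return requests.post(
        f"{BASE}/update_weights_from_distributed",
        json={"name": name, "dtype": dtype, "shape": shape},
    ).json()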
@@ -520,10 +594,77 @@ class TokenizerManager:

         while True:
             recv_obj: Union[
-                BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput
+                BatchStrOut,
+                BatchEmbeddingOut,
+                BatchTokenIDOut,
+                UpdateWeightFromDiskReqOutput,
+                UpdateWeightsFromDistributedReqOutput,
+                GetWeightsByNameReqOutput,
+                InitWeightsUpdateGroupReqOutput,
             ] = await self.recv_from_detokenizer.recv_pyobj()

-            if isinstance(recv_obj, UpdateWeightReqOutput):
+            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
+                for i, rid in enumerate(recv_obj.rids):
+                    state = self.rid_to_state.get(rid, None)
+                    if state is None:
+                        continue
+
+                    recv_obj.meta_info[i]["id"] = rid
+                    if isinstance(recv_obj, BatchStrOut):
+                        out_dict = {
+                            "text": recv_obj.output_strs[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    elif isinstance(recv_obj, BatchTokenIDOut):
+                        out_dict = {
+                            "token_ids": recv_obj.output_ids[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    else:
+                        assert isinstance(recv_obj, BatchEmbeddingOut)
+                        out_dict = {
+                            "embedding": recv_obj.embeddings[i],
+                            "meta_info": recv_obj.meta_info[i],
+                        }
+                    state.out_list.append(out_dict)
+                    state.finished = recv_obj.finished_reason[i] is not None
+                    state.event.set()
+
+                    if self.enable_metrics:
+                        completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+
+                        if state.first_token_time is None:
+                            state.first_token_time = time.time()
+                            self.metrics_collector.observe_time_to_first_token(
+                                state.first_token_time - state.created_time
+                            )
+                        else:
+                            if completion_tokens >= 2:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.first_token_time)
+                                    / (completion_tokens - 1)
+                                )
+
+                        if state.finished:
+                            self.metrics_collector.inc_prompt_tokens(
+                                recv_obj.meta_info[i]["prompt_tokens"]
+                            )
+                            self.metrics_collector.inc_generation_tokens(
+                                completion_tokens
+                            )
+                            self.metrics_collector.observe_e2e_request_latency(
+                                time.time() - state.created_time
+                            )
+                            if completion_tokens >= 1:
+                                self.metrics_collector.observe_time_per_output_token(
+                                    (time.time() - state.created_time)
+                                    / completion_tokens
+                                )
+            elif isinstance(recv_obj, OpenSessionReqOutput):
+                self.session_futures[recv_obj.session_id].set_result(
+                    recv_obj.session_id
+                )
+            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
                 if self.server_args.dp_size == 1:
                     self.model_update_result.set_result(recv_obj)
                 else:  # self.server_args.dp_size > 1
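
A tiny worked example of the latency accounting in the metrics branch above: the first chunk records time-to-first-token, and later chunks report the mean time per output token over the completion_tokens - 1 tokens generated after the first one.

created_time = 100.0
first_token_time = 100.4                            # TTFT observed as 0.4 s
now, completion_tokens = 102.4, 11
time_per_output_token = (now - first_token_time) / (completion_tokens - 1)
assert abs(time_per_output_token - 0.2) < 1e-9      # 10 follow-up tokens in 2.0 s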
@@ -531,70 +672,27 @@
                     # set future if the all results are recevied
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
-                continue
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id
-                )
-                continue
-
-            assert isinstance(
-                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
-            ), f"Unexpected obj received: {type(recv_obj)}"
-
-            for i, rid in enumerate(recv_obj.rids):
-                state = self.rid_to_state.get(rid, None)
-                if state is None:
-                    continue
-
-                recv_obj.meta_info[i]["id"] = rid
-                if isinstance(recv_obj, BatchStrOut):
-                    out_dict = {
-                        "text": recv_obj.output_strs[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                elif isinstance(recv_obj, BatchTokenIDOut):
-                    out_dict = {
-                        "token_ids": recv_obj.output_ids[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
+            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for init parameter update group"
+                self.init_weights_update_group_result.set_result(recv_obj)
+            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
+                assert (
+                    self.server_args.dp_size == 1
+                ), "dp_size must be 1 for update weights from distributed"
+                self.parameter_update_result.set_result(recv_obj)
+            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
+                if self.server_args.dp_size == 1:
+                    self.get_weights_by_name_result.set_result(recv_obj)
                 else:
-                    assert isinstance(recv_obj, BatchEmbeddingOut)
-                    out_dict = {
-                        "embedding": recv_obj.embeddings[i],
-                        "meta_info": recv_obj.meta_info[i],
-                    }
-                state.out_list.append(out_dict)
-                state.finished = recv_obj.finished_reason[i] is not None
-                state.event.set()
-
-                if self.enable_metrics:
-                    completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
-
-                    if state.first_token_time is None:
-                        state.first_token_time = time.time()
-                        self.metrics_collector.observe_time_to_first_token(
-                            state.first_token_time - state.created_time
+                    self.get_weights_by_name_tmp.append(recv_obj)
+                    if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
+                        self.get_weights_by_name_result.set_result(
+                            self.get_weights_by_name_tmp
                         )
-                    else:
-                        if completion_tokens >= 2:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.first_token_time)
-                                / (completion_tokens - 1)
-                            )
-
-                    if state.finished:
-                        self.metrics_collector.inc_prompt_tokens(
-                            recv_obj.meta_info[i]["prompt_tokens"]
-                        )
-                        self.metrics_collector.inc_generation_tokens(completion_tokens)
-                        self.metrics_collector.observe_e2e_request_latency(
-                            time.time() - state.created_time
-                        )
-                        if completion_tokens >= 1:
-                            self.metrics_collector.observe_time_per_output_token(
-                                (time.time() - state.created_time) / completion_tokens
-                            )
+            else:
+                raise ValueError(f"Invalid object: {recv_obj=}")

     def convert_logprob_style(
         self,
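
The dp_size > 1 branches above all follow the same fan-in pattern. A standalone sketch of that pattern: replies are buffered until one result per data-parallel replica has arrived, then a single asyncio.Future releases the awaiting coroutine.

import asyncio

async def demo(dp_size: int = 2):
    result_future = asyncio.get_running_loop().create_future()
    buffer = []

    def on_reply(reply):
        buffer.append(reply)
        if len(buffer) == dp_size:          # all replicas have reported back
            result_future.set_result(list(buffer))

    for rank in range(dp_size):
        on_reply(f"weights-from-rank-{rank}")
    print(await result_future)

asyncio.run(demo())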

sglang/srt/managers/tp_worker.py
@@ -19,7 +19,12 @@ from typing import Optional

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
@@ -47,9 +52,12 @@ class TpModelWorker:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -155,8 +163,33 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings

-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.model_runner.update_weights(
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.model_runner.update_weights_from_disk(
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.model_runner.init_weights_update_group(
+            recv_req.master_address,
+            recv_req.master_port,
+            recv_req.rank_offset,
+            recv_req.world_size,
+            recv_req.group_name,
+            recv_req.backend,
+        )
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.model_runner.update_weights_from_distributed(
+            recv_req.name, recv_req.dtype, recv_req.shape
+        )
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.model_runner.get_weights_by_name(
+            recv_req.name, recv_req.truncate_size
+        )
+        return parameter
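
The worker only unpacks the request fields; the heavy lifting happens in ModelRunner, which is not shown in this diff. A hedged sketch of what the receiving side of update_weights_from_distributed might look like (names and logic are illustrative, not sglang's implementation): allocate a buffer with the advertised dtype and shape, receive it over the weight-sync process group, and copy it into the named parameter.

import torch
import torch.distributed as dist

def receive_named_weight(model, group, name, dtype, shape):
    # Hypothetical helper; the trainer rank is assumed to be src=0 in the sync group.
    buf = torch.empty(shape, dtype=getattr(torch, dtype), device="cuda")
    dist.broadcast(buf, src=0, group=group)
    param = dict(model.named_parameters())[name]
    with torch.no_grad():
        param.copy_(buf.to(dtype=param.dtype, device=param.device))
    return True, f"updated {name}"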

sglang/srt/managers/tp_worker_overlap_thread.py
@@ -23,7 +23,12 @@ from typing import Optional
 import psutil
 import torch

-from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.io_struct import (
+    GetWeightsByNameReqInput,
+    InitWeightsUpdateGroupReqInput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightsFromDistributedReqInput,
+)
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
@@ -204,10 +209,23 @@ class TpModelWorkerClient:
         ) % self.future_token_ids_limit
         return None, future_next_token_ids

-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        success, message = self.worker.update_weights(recv_req)
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        success, message = self.worker.update_weights_from_disk(recv_req)
         return success, message

+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        success, message = self.worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        success, message = self.worker.update_weights_from_distributed(recv_req)
+        return success, message
+
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        return self.worker.get_weights_by_name(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))

sglang/srt/model_executor/cuda_graph_runner.py
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner


-def _to_torch(model: torch.nn.Module, reverse: bool = False):
+def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -45,24 +45,30 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    sub._forward_method = fused_moe_forward_native
+                    if batch_size == 1:
+                        # The performance of torch.compile on this layer is not always good when bs > 1,
+                        # so we decide to skip it for now.
+                        sub._forward_method = fused_moe_forward_native
                 else:
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse)
+            _to_torch(sub, reverse, batch_size)


 @contextmanager
 def patch_model(
-    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module,
+    enable_compile: bool,
+    batch_size: int,
+    tp_group: "GroupCoordinator",
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None

     try:
         if enable_compile:
-            _to_torch(model)
+            _to_torch(model, reverse=False, batch_size=batch_size)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
@@ -70,13 +76,15 @@ def patch_model(
             # even with ENABLE_INTRA_NODE_COMM=1.
             # tp_group.ca_comm = None
             yield torch.compile(
-                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+                torch.no_grad()(model.forward),
+                mode="max-autotune-no-cudagraphs",
+                dynamic=False,
             )
         else:
             yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True)
+            _to_torch(model, reverse=True, batch_size=batch_size)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm

@@ -237,6 +245,7 @@ class CudaGraphRunner:
             with patch_model(
                 self.model_runner.model,
                 bs in self.compile_bs,
+                bs,
                 self.model_runner.tp_group,
             ) as forward:
                 (
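
The capture loop above now passes each captured batch size into patch_model, and compilation runs with dynamic=False, i.e. one specialized graph per batch size. A minimal standalone sketch of the patch -> compile -> restore pattern, with a toy module standing in for sglang's CustomOp layers (ToyOp and patch_for_compile are illustrative names, not sglang APIs):

from contextlib import contextmanager

import torch

class ToyOp(torch.nn.Module):
    def forward(self, x):            # stand-in for the custom/CUDA path
        return torch.nn.functional.silu(x)

    def forward_native(self, x):     # torch-native path preferred under torch.compile
        return x * torch.sigmoid(x)

@contextmanager
def patch_for_compile(model):
    originals = {}
    for name, sub in model.named_modules():
        if isinstance(sub, ToyOp):
            originals[name] = sub.forward
            sub.forward = sub.forward_native   # swap in the compile-friendly forward
    try:
        # dynamic=False mirrors the change above: one specialized graph per batch size.
        yield torch.compile(torch.no_grad()(model.forward), dynamic=False)
    finally:
        for name, sub in model.named_modules():
            if name in originals:
                sub.forward = originals[name]  # restore the original forward

model = torch.nn.Sequential(ToyOp())
with patch_for_compile(model) as compiled_forward:
    out = compiled_forward(torch.randn(1, 8))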

sglang/srt/model_executor/forward_batch_info.py
@@ -256,10 +256,15 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            ret.extend_num_tokens = batch.extend_num_tokens
-            ret.positions, ret.extend_start_loc = compute_position_triton(
-                ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
-            )
+            if model_runner.server_args.attention_backend != "torch_native":
+                ret.extend_num_tokens = batch.extend_num_tokens
+                ret.positions, ret.extend_start_loc = compute_position_triton(
+                    ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
+                )
+            else:
+                ret.positions, ret.extend_start_loc = compute_position_torch(
+                    ret.extend_prefix_lens, ret.extend_seq_lens
+                )
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
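
For the new torch_native attention backend the Triton position kernel is skipped in favor of a pure-torch computation. A hedged re-implementation sketch of that fallback (compute_position_torch itself is not shown in this diff): each sequence's positions run from its cached prefix length up to prefix + new tokens, and extend_start_loc is the exclusive prefix sum of the new-token counts.

import torch

def compute_position_torch_sketch(extend_prefix_lens, extend_seq_lens):
    positions = torch.cat(
        [
            torch.arange(p, p + s, device=extend_seq_lens.device)
            for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    extend_start_loc = torch.zeros_like(extend_seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
    return positions, extend_start_loc

# prefix_lens=[2, 0], seq_lens=[3, 2] -> positions=[2, 3, 4, 0, 1], start_loc=[0, 3]
pos, loc = compute_position_torch_sketch(torch.tensor([2, 0]), torch.tensor([3, 2]))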