sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -20,6 +20,7 @@ import random
 import tempfile
 from typing import List, Optional
 
+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_nvgpu_memory_capacity,
@@ -49,6 +50,7 @@ class ServerArgs:
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
+    revision: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -58,7 +60,7 @@ class ServerArgs:
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
     max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int = 8192
+    chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -120,7 +122,7 @@ class ServerArgs:
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
-    disable_disk_cache: bool = False
+    disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_overlap_schedule: bool = False
@@ -128,7 +130,7 @@ class ServerArgs:
     enable_dp_attention: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
-    cuda_graph_max_bs: int = 160
+    cuda_graph_max_bs: Optional[int] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -144,19 +146,20 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
-        if self.chunked_prefill_size <= 0:
-            # Disable chunked prefill
-            self.chunked_prefill_size = None
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
-        # Mem fraction depends on the tensor parallelism size
+        if is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
+        else:
+            gpu_mem = get_nvgpu_memory_capacity()
+
+        # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.81
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -164,25 +167,35 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
-        # Adjust for GPUs with small memory capacities
-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        else:
-            gpu_mem = get_nvgpu_memory_capacity()
-        if gpu_mem < 25000:
-            self.chunked_prefill_size //= 4  # make it 2048
-            self.cuda_graph_max_bs = 4
-            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
+        # Set chunked prefill size, which depends on the gpu memory capacity
+        if self.chunked_prefill_size is None:
+            if gpu_mem < 25_000:
+                self.chunked_prefill_size = 2048
+            else:
+                self.chunked_prefill_size = 8192
 
-        # Choose kernel backends
-        if not is_flashinfer_available():
-            self.attention_backend = "triton"
-            self.sampling_backend = "pytorch"
+        # Set cuda graph max batch size
+        if self.cuda_graph_max_bs is None:
+            if gpu_mem < 25_000:
+                self.cuda_graph_max_bs = 8
+            else:
+                self.cuda_graph_max_bs = 160
 
+        # Choose kernel backends
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.warning(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True
 
         # Others
         if self.enable_dp_attention:
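The defaulting logic above can be read in isolation: chunked prefill size and CUDA graph batch size now both key off the detected GPU memory, and the kernel backends fall back to triton/pytorch when flashinfer is unavailable. A minimal sketch of the same decision table, with a hypothetical gpu_mem_mib value standing in for the detected capacity:

from typing import Optional, Tuple


def resolve_runtime_defaults(
    gpu_mem_mib: float,
    flashinfer_available: bool,
    chunked_prefill_size: Optional[int] = None,
    cuda_graph_max_bs: Optional[int] = None,
) -> Tuple[int, int, str, str]:
    # Mirrors the v0.4.0 __post_init__ thresholds: GPUs under ~25 GB get a
    # smaller prefill chunk and a smaller CUDA graph batch size.
    if chunked_prefill_size is None:
        chunked_prefill_size = 2048 if gpu_mem_mib < 25_000 else 8192
    if cuda_graph_max_bs is None:
        cuda_graph_max_bs = 8 if gpu_mem_mib < 25_000 else 160
    attention_backend = "flashinfer" if flashinfer_available else "triton"
    sampling_backend = "flashinfer" if flashinfer_available else "pytorch"
    return chunked_prefill_size, cuda_graph_max_bs, attention_backend, sampling_backend


# A 24 GB card without flashinfer resolves to (2048, 8, "triton", "pytorch").
print(resolve_runtime_defaults(24_000, flashinfer_available=False))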
@@ -191,14 +204,20 @@ class ServerArgs:
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
-            logger.info(
+            logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap schedule is disabled."
+                "Overlap scheduler is disabled."
             )
 
+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
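check_gguf_file comes from sglang.srt.hf_transformers_utils (also touched in this release). Its exact implementation is not shown in this diff; a rough stand-in that follows the GGUF container convention (files start with the magic bytes b"GGUF") could look like this:

from pathlib import Path


def looks_like_gguf(model_path: str) -> bool:
    """Rough stand-in for check_gguf_file: treat a local file whose first four
    bytes are the GGUF magic as a GGUF checkpoint. Remote repo ids return False."""
    path = Path(model_path)
    if not path.is_file():
        return False
    with path.open("rb") as f:
        return f.read(4) == b"GGUF"


# When this returns True and --load-format is "auto", v0.4.0 switches both
# load_format and quantization to "gguf" automatically.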
@@ -238,7 +257,7 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -248,7 +267,8 @@ class ServerArgs:
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -288,6 +308,7 @@ class ServerArgs:
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
@@ -321,6 +342,14 @@ class ServerArgs:
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--revision",
+            type=str,
+            default=None,
+            help="The specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
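The new --revision flag flows into the ServerArgs.revision field added above, so it is also reachable through the same add_cli_args / from_cli_args pair shown in this file. A small illustration; the repo id and revision are placeholders, --model-path is the existing flag for model_path, and __post_init__ probes GPU memory, so this expects a visible GPU:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)

# Pin the Hugging Face revision (branch, tag, or commit id) when loading weights.
args = ServerArgs.from_cli_args(
    parser.parse_args(["--model-path", "org/some-model", "--revision", "main"])
)
print(args.revision, args.load_format)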
@@ -572,7 +601,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -613,9 +642,9 @@ class ServerArgs:
             help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
         )
         parser.add_argument(
-            "--disable-disk-cache",
+            "--disable-outlines-disk-cache",
             action="store_true",
-            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
+            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
         )
         parser.add_argument(
             "--disable-custom-all-reduce",
@@ -716,6 +745,11 @@ class ServerArgs:
             action=DeprecatedAction,
             help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
         )
+        parser.add_argument(
+            "--disable-disk-cache",
+            action=DeprecatedAction,
+            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
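The renamed flag keeps a deprecation shim: --disable-disk-cache is still accepted but routed through DeprecatedAction. The diff does not show DeprecatedAction itself; a typical argparse action of that shape (purely illustrative, not sglang's implementation) is:

import argparse


class DeprecatedFlag(argparse.Action):
    """Illustrative stand-in: reject a removed flag with a pointer to its replacement."""

    def __init__(self, option_strings, dest, nargs=0, help=None, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, help=help, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Fail fast with the migration hint instead of silently accepting the flag.
        parser.error(self.help or f"{option_string} is deprecated.")


parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-disk-cache",
    action=DeprecatedFlag,
    help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
)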
sglang/srt/utils.py CHANGED
@@ -30,6 +30,7 @@ import subprocess
 import tempfile
 import time
 import warnings
+from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
@@ -38,6 +39,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
@@ -67,6 +69,22 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
+def is_cuda():
+    return hasattr(torch, "cuda") and torch.cuda.is_available()
+
+
+def is_cuda_alike():
+    return is_cuda() or is_hip()
+
+
+def is_hpu() -> bool:
+    return hasattr(torch, "hpu") and torch.hpu.is_available()
+
+
+def is_xpu() -> bool:
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
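These platform probes are mostly used as guards elsewhere in the codebase. A small illustrative helper (not part of sglang) that folds them into a single device choice:

import torch

from sglang.srt.utils import is_cuda, is_hip, is_hpu, is_xpu


def pick_device() -> torch.device:
    # Illustrative only: prefer CUDA/ROCm, then Intel XPU, then Habana HPU, else CPU.
    if is_cuda() or is_hip():
        return torch.device("cuda")
    if is_xpu():
        return torch.device("xpu")
    if is_hpu():
        return torch.device("hpu")
    return torch.device("cpu")


print(pick_device())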
@@ -412,16 +430,12 @@ def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
 
     vllm_default_logger.setLevel(logging.WARN)
-    logging.getLogger("vllm.config").setLevel(logging.ERROR)
     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
         logging.WARN
     )
     logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
         logging.WARN
     )
-    logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
-    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
 
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -443,26 +457,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
     )
 
 
-def kill_parent_process():
-    """Kill the parent process and all children of the parent process."""
-    current_process = psutil.Process()
-    parent_process = current_process.parent()
-    kill_child_process(
-        parent_process.pid, include_self=True, skip_pid=current_process.pid
-    )
-    try:
-        current_process.kill()
-    except psutil.NoSuchProcess:
-        pass
-
-
-def kill_child_process(pid=None, include_self=False, skip_pid=None):
-    """Kill the process and all its children process."""
-    if pid is None:
-        pid = os.getpid()
+def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
+    """Kill the process and all its child processes."""
+    if parent_pid is None:
+        parent_pid = os.getpid()
+        include_parent = False
 
     try:
-        itself = psutil.Process(pid)
+        itself = psutil.Process(parent_pid)
     except psutil.NoSuchProcess:
         return
 
@@ -475,38 +477,17 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
     except psutil.NoSuchProcess:
         pass
 
-    if include_self:
+    if include_parent:
         try:
             itself.kill()
 
             # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
             # so we send an additional signal to kill them.
-            itself.send_signal(signal.SIGINT)
+            itself.send_signal(signal.SIGQUIT)
         except psutil.NoSuchProcess:
             pass
 
 
-def monkey_patch_vllm_model_config():
-    from vllm.config import ModelConfig
-
-    if not hasattr(ModelConfig, "_resolve_task"):
-        return
-
-    def _resolve_task(
-        self,
-        task_option,
-        hf_config,
-    ):
-        supported_tasks = {
-            "generate": True,
-            "embedding": False,
-        }
-        selected_task = "generate"
-        return supported_tasks, selected_task
-
-    setattr(ModelConfig, "_resolve_task", _resolve_task)
-
-
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
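kill_process_tree replaces both kill_parent_process and kill_child_process with a single entry point: pass a PID to kill that process and its descendants, or None to target only the current process's children. A usage sketch; the worker command below is a placeholder:

import subprocess
import time

from sglang.srt.utils import kill_process_tree

# Placeholder worker: a shell that spawns a sleeping child of its own.
proc = subprocess.Popen(["/bin/sh", "-c", "sleep 300 & wait"])
time.sleep(1)

# Kill the worker and everything it spawned. With include_parent=False only the
# descendants are killed; with parent_pid=None the current process's children
# are targeted instead.
kill_process_tree(proc.pid, include_parent=True)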
@@ -569,6 +550,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)
 
 
+def monkey_patch_vllm_gguf_config():
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gguf import (
+        GGUFConfig,
+        GGUFEmbeddingMethod,
+        GGUFLinearMethod,
+    )
+
+    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+    def get_quant_method_with_embedding_replaced(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            # patch to own VocabParallelEmbedding
+            return GGUFEmbeddingMethod(self)
+        return None
+
+    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
+
+
 def maybe_set_triton_cache_manager() -> None:
     """Set environment variable to tell Triton to use a
     custom cache manager"""
@@ -874,7 +878,9 @@ def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
        result = subprocess.run(
-            ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
+            [
+                "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'"
+            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            shell=True,
885
891
 
886
892
  # Parse the output to extract memory values in MiB
887
893
  memory_values = [
888
- float(mem) / 1024 / 1024
894
+ float(mem.split("(")[0].strip()) / 1024
889
895
  for mem in result.stdout.strip().split("\n")
890
- if re.match(r"^\d+(\.\d+)?$", mem.strip())
891
896
  ]
892
897
 
893
898
  if not memory_values:
@@ -934,11 +939,88 @@ def get_nvgpu_memory_capacity():
     )
 
 
+# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
+def init_custom_process_group(
+    backend=None,
+    init_method=None,
+    timeout=None,
+    world_size=-1,
+    rank=-1,
+    store=None,
+    group_name=None,
+    pg_options=None,
+):
+    from torch.distributed.distributed_c10d import (
+        Backend,
+        PrefixStore,
+        _new_process_group_helper,
+        _world,
+        default_pg_timeout,
+        rendezvous,
+    )
+
+    assert (store is None) or (
+        init_method is None
+    ), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = default_pg_timeout
+
+    # backward compatible API
+    if store is None:
+        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+        store, rank, world_size = next(rendezvous_iterator)
+        store.set_timeout(timeout)
+
+        # Use a PrefixStore to avoid accidental overrides of keys used by
+        # different systems (e.g. RPC) in case the store is multi-tenant.
+        store = PrefixStore(group_name, store)
+
+    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+    # We need to determine the appropriate parameter name based on PyTorch version
+    pg_options_param_name = (
+        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+    )
+    pg, _ = _new_process_group_helper(
+        world_size,
+        rank,
+        [],
+        backend,
+        store,
+        group_name=group_name,
+        **{pg_options_param_name: pg_options},
+        timeout=timeout,
+    )
+
+    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+    return pg
+
+
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
     return get_bool_env_var("SGLANG_IS_IN_CI")
 
 
+def print_warning_once(msg: str) -> None:
+    # Set the stacklevel to 2 to print the caller's line info
+    logger.warning(msg, stacklevel=2)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
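init_custom_process_group returns a standalone ProcessGroup instead of touching torch.distributed's default group, which is what allows, for example, an extra weight-sync group between a trainer and an inference server. A hedged two-rank sketch; the address and group name are placeholders:

import torch
import torch.distributed as dist

from sglang.srt.utils import init_custom_process_group

# Rank 0 side; a second process runs the same call with rank=1. The call blocks
# until both ranks reach the rendezvous.
group = init_custom_process_group(
    backend="gloo",                        # use "nccl" for GPU tensors
    init_method="tcp://127.0.0.1:29599",   # placeholder rendezvous address
    world_size=2,
    rank=0,
    group_name="weight_sync",              # placeholder group name
)

t = torch.zeros(1)
dist.broadcast(t, src=0, group=group)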
@@ -953,9 +1035,42 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)
 
 
+def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+    major, minor = None, None
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        major, minor = torch.cuda.get_device_capability(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
+            "."
+        )
+        major, minor = int(major), int(minor)
+
+    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+    # Update this once the support is available.
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        try:
+            major, minor = torch.hpu.get_device_capability(device_id)
+        except Exception as e:
+            raise RuntimeError(
+                f"An error occurred while getting device capability of hpu: {e}."
+            ) from e
+
+    return major, minor
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
 
 
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
 def direct_register_custom_op(
     op_name: str,
     op_func: Callable,
@@ -1032,3 +1147,93 @@ def set_gpu_proc_affinity(
 def get_bool_env_var(name: str, default: str = "false") -> bool:
     value = os.getenv(name, default)
     return value.lower() in ("true", "1")
+
+
+@lru_cache(maxsize=8)
+def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
+    # Note: cuda_visible_devices is not used, but we keep it as an argument for
+    # LRU Cache purposes.
+
+    # Code below is based on
+    # https://github.com/pytorch/pytorch/blob/
+    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
+    # torch/cuda/__init__.py#L831C1-L831C17
+    import torch.cuda
+    import torch.version
+
+    if not torch.cuda._is_compiled():
+        return 0
+    if is_hip():
+        # ROCm uses amdsmi instead of nvml for stateless device count
+        # This requires a sufficiently modern version of Torch 2.4.0
+        raw_count = (
+            torch.cuda._device_count_amdsmi()
+            if (hasattr(torch.cuda, "_device_count_amdsmi"))
+            else -1
+        )
+    else:
+        raw_count = torch.cuda._device_count_nvml()
+    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
+    return r
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
+def cuda_device_count_stateless() -> int:
+    """Get number of CUDA devices, caching based on the value of
+    CUDA_VISIBLE_DEVICES at the time of call.
+
+    This should be used instead of torch.cuda.device_count()
+    unless CUDA_VISIBLE_DEVICES has already been set to the desired
+    value."""
+
+    # This can be removed and simply replaced with torch.cuda.get_device_count
+    # after https://github.com/pytorch/pytorch/pull/122815 is released.
+    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
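should_use_tensor_core centralizes the flashinfer decode heuristic: the SGLANG_FLASHINFER_USE_TENSOR_CORE environment variable overrides everything, old flashinfer builds are probed directly, and otherwise FP8 KV caches always take the tensor-core path while FP16/BF16 caches use it only when the GQA group size exceeds 4. Two worked calls; the head counts are illustrative:

import torch

from sglang.srt.utils import should_use_tensor_core

# Assuming no env override and a flashinfer build without
# _grouped_size_compiled_for_decode_kernels, the decision reduces to
# dtype plus GQA group size.

# 32 query heads / 8 KV heads -> group size 4, so bf16 stays off tensor cores.
print(should_use_tensor_core(torch.bfloat16, num_attention_heads=32, num_kv_heads=8))

# An FP8 KV cache takes the tensor-core path regardless of the head ratio.
print(should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=64, num_kv_heads=8))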