sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -20,8 +20,12 @@ import random
 import tempfile
 from typing import List, Optional
 
+import torch
+
+from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
@@ -49,6 +53,7 @@ class ServerArgs:
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
+    revision: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -58,7 +63,7 @@
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
     max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int = 8192
+    chunked_prefill_size: Optional[int] = None
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -89,6 +94,8 @@
     # Data parallelism
     dp_size: int = 1
     load_balance_method: str = "round_robin"
+    # Expert parallelism
+    ep_size: int = 1
 
     # Multi-node distributed serving
     dist_init_addr: Optional[str] = None
@@ -120,15 +127,16 @@
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
-    disable_disk_cache: bool = False
+    disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
+    enable_ep_moe: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
-    cuda_graph_max_bs: int = 160
+    cuda_graph_max_bs: Optional[int] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -144,19 +152,25 @@
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
-        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
-            # Disable chunked prefill
-            self.chunked_prefill_size = None
-
         if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)
 
-        # Mem fraction depends on the tensor parallelism size
+        if is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
+        elif torch.cuda.is_available():
+            gpu_mem = get_nvgpu_memory_capacity()
+        elif self.device == "hpu":
+            gpu_mem = get_hpu_memory_capacity()
+        else:
+            # GPU memory is not known yet or no GPU is available.
+            gpu_mem = None
+
+        # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.82
+                self.mem_fraction_static = 0.81
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -164,25 +178,43 @@
             else:
                 self.mem_fraction_static = 0.88
 
-        # Adjust for GPUs with small memory capacities
-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        else:
-            gpu_mem = get_nvgpu_memory_capacity()
-        if gpu_mem < 25000:
-            self.chunked_prefill_size //= 4  # make it 2048
-            self.cuda_graph_max_bs = 4
-            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
+        # Set chunked prefill size, which depends on the gpu memory capacity
+        if self.chunked_prefill_size is None:
+            if gpu_mem is not None and gpu_mem < 25_000:
+                self.chunked_prefill_size = 2048
+            else:
+                self.chunked_prefill_size = 8192
+
+        # Set cuda graph max batch size
+        if self.cuda_graph_max_bs is None:
+            # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
+            if gpu_mem is not None and gpu_mem < 25_000:
+                if self.tp_size < 4:
+                    self.cuda_graph_max_bs = 8
+                else:
+                    self.cuda_graph_max_bs = 80
+            else:
+                self.cuda_graph_max_bs = 160
 
         # Choose kernel backends
-        if not is_flashinfer_available():
-            self.attention_backend = "triton"
+        if self.device == "hpu":
+            self.attention_backend = "torch_native"
             self.sampling_backend = "pytorch"
 
         if self.attention_backend is None:
-            self.attention_backend = "flashinfer"
+            self.attention_backend = (
+                "flashinfer" if is_flashinfer_available() else "triton"
+            )
         if self.sampling_backend is None:
-            self.sampling_backend = "flashinfer"
+            self.sampling_backend = (
+                "flashinfer" if is_flashinfer_available() else "pytorch"
+            )
+
+        if self.attention_backend == "torch_native":
+            logger.warning(
+                "Cuda graph is disabled because of using torch native attention backend"
+            )
+            self.disable_cuda_graph = True
 
         # Others
         if self.enable_dp_attention:
@@ -191,13 +223,25 @@
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
-            logger.info(
+            logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap schedule is disabled."
+                "Overlap scheduler is disabled."
             )
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
+        # GGUF
+        if (
+            self.load_format == "auto" or self.load_format == "gguf"
+        ) and check_gguf_file(self.model_path):
+            self.quantization = self.load_format = "gguf"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
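
The hunks above replace the old fixed defaults with device-aware auto-configuration: chunked_prefill_size and cuda_graph_max_bs now start as None and __post_init__ fills them in from the detected GPU memory and the tensor-parallel size. A minimal standalone sketch of that selection rule, with the 25 GB threshold and the 8/80/160 batch sizes taken from the diff (the helper name pick_memory_defaults is illustrative, not part of the package):

    from typing import Optional, Tuple


    def pick_memory_defaults(gpu_mem_mb: Optional[float], tp_size: int) -> Tuple[int, int]:
        """Return (chunked_prefill_size, cuda_graph_max_bs) the way the new __post_init__ does."""
        small_gpu = gpu_mem_mb is not None and gpu_mem_mb < 25_000
        chunked_prefill_size = 2048 if small_gpu else 8192
        if small_gpu:
            # Low-HBM cards: tiny graphs for TP1/TP2, 80 for TP4/TP8, so cuda graph stays on
            # without the memory overhead of the previous fixed default of 160.
            cuda_graph_max_bs = 8 if tp_size < 4 else 80
        else:
            cuda_graph_max_bs = 160
        return chunked_prefill_size, cuda_graph_max_bs


    print(pick_memory_defaults(24_000, 1))   # (2048, 8)   e.g. a 24 GB card served with TP1
    print(pick_memory_defaults(81_000, 8))   # (8192, 160) e.g. an 80 GB card served with TP8
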
@@ -238,7 +282,7 @@
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -248,7 +292,8 @@
             '"npcache" will load the weights in pytorch format and store '
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            "which is mainly for profiling.",
+            "which is mainly for profiling."
+            '"gguf" will load the weights in the gguf format. ',
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -288,6 +333,7 @@
                 "gptq_marlin",
                 "awq_marlin",
                 "bitsandbytes",
+                "gguf",
             ],
             help="The quantization method.",
         )
@@ -321,6 +367,14 @@
             action="store_true",
             help="Whether to use a CausalLM as an embedding model.",
         )
+        parser.add_argument(
+            "--revision",
+            type=str,
+            default=None,
+            help="The specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
 
         # Memory and scheduling
         parser.add_argument(
@@ -492,6 +546,14 @@
                 "shortest_queue",
             ],
         )
+        # Expert parallelism
+        parser.add_argument(
+            "--expert-parallel-size",
+            "--ep-size",
+            type=int,
+            default=ServerArgs.ep_size,
+            help="The expert parallelism size.",
+        )
 
         # Multi-node distributed serving
         parser.add_argument(
@@ -572,7 +634,7 @@
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton"],
+            choices=["flashinfer", "triton", "torch_native"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -613,9 +675,9 @@
             help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
         )
         parser.add_argument(
-            "--disable-disk-cache",
+            "--disable-outlines-disk-cache",
             action="store_true",
-            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
+            help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
         )
         parser.add_argument(
             "--disable-custom-all-reduce",
@@ -647,6 +709,11 @@
             action="store_true",
             help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
         )
+        parser.add_argument(
+            "--enable-ep-moe",
+            action="store_true",
+            help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -716,11 +783,17 @@
             action=DeprecatedAction,
             help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
         )
+        parser.add_argument(
+            "--disable-disk-cache",
+            action=DeprecatedAction,
+            help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
+        args.ep_size = args.expert_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
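
Taken together, the CLI changes above add --revision, --enable-ep-moe, --expert-parallel-size/--ep-size, the gguf load format and quantization choice, and the torch_native attention backend, and rename --disable-disk-cache to --disable-outlines-disk-cache while keeping the old flag as a deprecated alias. A sketch of how the new flags flow through the existing add_cli_args/from_cli_args pair; the model path is a placeholder, and the printed values assume the auto-configuration in __post_init__ runs as shown in the diff:

    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args(
        [
            "--model-path", "org/placeholder-moe-model",  # placeholder, not a real checkpoint
            "--tp-size", "4",
            "--enable-ep-moe",                            # new in 0.4.0
            "--attention-backend", "torch_native",        # new backend choice
        ]
    )
    server_args = ServerArgs.from_cli_args(args)
    print(server_args.ep_size)              # 4: copied from tp_size because EP MoE is enabled
    print(server_args.disable_cuda_graph)   # True: forced off for the torch_native backend
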
sglang/srt/utils.py CHANGED
@@ -30,6 +30,7 @@ import subprocess
 import tempfile
 import time
 import warnings
+from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
@@ -38,6 +39,7 @@ import numpy as np
 import psutil
 import requests
 import torch
+import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
@@ -67,6 +69,22 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
+def is_cuda():
+    return hasattr(torch, "cuda") and torch.cuda.is_available()
+
+
+def is_cuda_alike():
+    return is_cuda() or is_hip()
+
+
+def is_hpu() -> bool:
+    return hasattr(torch, "hpu") and torch.hpu.is_available()
+
+
+def is_xpu() -> bool:
+    return hasattr(torch, "xpu") and torch.xpu.is_available()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
@@ -183,6 +201,18 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory
 
+    elif device == "hpu":
+        num_gpus = torch.hpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.hpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.hpu.current_device()}, ",
+                "which may cause useless memory allocation for torch HPU context.",
+            )
+
+        free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
@@ -412,16 +442,12 @@ def suppress_other_loggers():
     from vllm.logger import logger as vllm_default_logger
 
     vllm_default_logger.setLevel(logging.WARN)
-    logging.getLogger("vllm.config").setLevel(logging.ERROR)
     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
         logging.WARN
     )
     logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
         logging.WARN
     )
-    logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
-    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
 
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -474,27 +500,6 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
         pass
 
 
-def monkey_patch_vllm_model_config():
-    from vllm.config import ModelConfig
-
-    if not hasattr(ModelConfig, "_resolve_task"):
-        return
-
-    def _resolve_task(
-        self,
-        task_option,
-        hf_config,
-    ):
-        supported_tasks = {
-            "generate": True,
-            "embedding": False,
-        }
-        selected_task = "generate"
-        return supported_tasks, selected_task
-
-    setattr(ModelConfig, "_resolve_task", _resolve_task)
-
-
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
@@ -557,6 +562,29 @@ def monkey_patch_vllm_all_gather(reverse: bool = False):
     setattr(GroupCoordinator, "all_gather", all_gather)
 
 
+def monkey_patch_vllm_gguf_config():
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gguf import (
+        GGUFConfig,
+        GGUFEmbeddingMethod,
+        GGUFLinearMethod,
+    )
+
+    from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+
+    def get_quant_method_with_embedding_replaced(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return GGUFLinearMethod(self)
+        elif isinstance(layer, VocabParallelEmbedding):
+            # patch to own VocabParallelEmbedding
+            return GGUFEmbeddingMethod(self)
+        return None
+
+    setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced)
+
+
 def maybe_set_triton_cache_manager() -> None:
     """Set environment variable to tell Triton to use a
     custom cache manager"""
@@ -862,7 +890,9 @@ def get_amdgpu_memory_capacity():
     try:
         # Run rocm-smi and capture the output
         result = subprocess.run(
-            ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
+            [
+                "rocminfo | grep 'gfx' -A 100 | grep 'Pool 1' -A 5 | grep 'Size:' | awk '{print $2}'"
+            ],
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             shell=True,
@@ -873,9 +903,8 @@
 
         # Parse the output to extract memory values in MiB
         memory_values = [
-            float(mem) / 1024 / 1024
+            float(mem.split("(")[0].strip()) / 1024
            for mem in result.stdout.strip().split("\n")
-            if re.match(r"^\d+(\.\d+)?$", mem.strip())
         ]
 
         if not memory_values:
@@ -922,11 +951,119 @@ def get_nvgpu_memory_capacity():
         )
 
 
+def get_hpu_memory_capacity():
+    try:
+        # Run hl-smi and capture the output
+        result = subprocess.run(
+            ["hl-smi --query | grep 'Total'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"hl-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem.split(" ")[-2]) for mem in result.stdout.strip().split("\n")
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "hl-smi not found. Ensure Habana drivers are installed and accessible."
+        )
+
+
+# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
+# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
+# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
+def init_custom_process_group(
+    backend=None,
+    init_method=None,
+    timeout=None,
+    world_size=-1,
+    rank=-1,
+    store=None,
+    group_name=None,
+    pg_options=None,
+):
+    from torch.distributed.distributed_c10d import (
+        Backend,
+        PrefixStore,
+        _new_process_group_helper,
+        _world,
+        default_pg_timeout,
+        rendezvous,
+    )
+
+    assert (store is None) or (
+        init_method is None
+    ), "Cannot specify both init_method and store."
+
+    if store is not None:
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
+    elif init_method is None:
+        init_method = "env://"
+
+    if backend:
+        backend = Backend(backend)
+    else:
+        backend = Backend("undefined")
+
+    if timeout is None:
+        timeout = default_pg_timeout
+
+    # backward compatible API
+    if store is None:
+        rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
+        store, rank, world_size = next(rendezvous_iterator)
+        store.set_timeout(timeout)
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    store = PrefixStore(group_name, store)
+
+    # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
+    # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
+    # We need to determine the appropriate parameter name based on PyTorch version
+    pg_options_param_name = (
+        "backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
+    )
+    pg, _ = _new_process_group_helper(
+        world_size,
+        rank,
+        [],
+        backend,
+        store,
+        group_name=group_name,
+        **{pg_options_param_name: pg_options},
+        timeout=timeout,
+    )
+
+    _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
+
+    return pg
+
+
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
     return get_bool_env_var("SGLANG_IS_IN_CI")
 
 
+def print_warning_once(msg: str) -> None:
+    # Set the stacklevel to 2 to print the caller's line info
+    logger.warning(msg, stacklevel=2)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
@@ -941,9 +1078,49 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)
 
 
+def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+    major, minor = None, None
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        major, minor = torch.cuda.get_device_capability(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
+            "."
+        )
+        major, minor = int(major), int(minor)
+
+    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+    # Update this once the support is available.
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        try:
+            major, minor = torch.hpu.get_device_capability(device_id)
+        except Exception as e:
+            raise RuntimeError(
+                f"An error occurred while getting device capability of hpu: {e}."
+            ) from e
+
+    return major, minor
+
+
+def get_compiler_backend() -> str:
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return "hpu_backend"
+
+    return "inductor"
+
+
 sglang_lib = Library("sglang", "FRAGMENT")  # noqa
 
 
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
 def direct_register_custom_op(
     op_name: str,
     op_func: Callable,
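
get_compiler_backend and supports_custom_op added above are small capability probes introduced alongside the HPU support. A sketch of how such probes are typically consumed; the maybe_compile wrapper is illustrative and not code from the package:

    import torch

    from sglang.srt.utils import get_compiler_backend, supports_custom_op


    def maybe_compile(module: torch.nn.Module) -> torch.nn.Module:
        # get_compiler_backend() returns "hpu_backend" on Gaudi and "inductor" everywhere else.
        return torch.compile(module, backend=get_compiler_backend())


    if not supports_custom_op():
        # torch.library.custom_op only exists on PyTorch >= 2.4; presumably this probe guards
        # code paths that register sglang custom ops (see direct_register_custom_op below).
        print("torch.library.custom_op is unavailable on this PyTorch build")
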
@@ -1020,3 +1197,93 @@ def set_gpu_proc_affinity(
 def get_bool_env_var(name: str, default: str = "false") -> bool:
     value = os.getenv(name, default)
     return value.lower() in ("true", "1")
+
+
+@lru_cache(maxsize=8)
+def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
+    # Note: cuda_visible_devices is not used, but we keep it as an argument for
+    # LRU Cache purposes.
+
+    # Code below is based on
+    # https://github.com/pytorch/pytorch/blob/
+    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
+    # torch/cuda/__init__.py#L831C1-L831C17
+    import torch.cuda
+    import torch.version
+
+    if not torch.cuda._is_compiled():
+        return 0
+    if is_hip():
+        # ROCm uses amdsmi instead of nvml for stateless device count
+        # This requires a sufficiently modern version of Torch 2.4.0
+        raw_count = (
+            torch.cuda._device_count_amdsmi()
+            if (hasattr(torch.cuda, "_device_count_amdsmi"))
+            else -1
+        )
+    else:
+        raw_count = torch.cuda._device_count_nvml()
+    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
+    return r
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
+def cuda_device_count_stateless() -> int:
+    """Get number of CUDA devices, caching based on the value of
+    CUDA_VISIBLE_DEVICES at the time of call.
+
+    This should be used instead of torch.cuda.device_count()
+    unless CUDA_VISIBLE_DEVICES has already been set to the desired
+    value."""
+
+    # This can be removed and simply replaced with torch.cuda.get_device_count
+    # after https://github.com/pytorch/pytorch/pull/122815 is released.
+    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
+
+
+def should_use_tensor_core(
+    kv_cache_dtype: torch.dtype,
+    num_attention_heads: int,
+    num_kv_heads: int,
+) -> bool:
+    """
+    Determine whether to use tensor cores for attention computation.
+
+    Args:
+        kv_cache_dtype: Data type of the KV cache
+        num_attention_heads: Number of attention heads
+        num_kv_heads: Number of key/value heads
+
+    Returns:
+        bool: Whether to use tensor cores
+    """
+    # Try to use environment variable first
+    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
+    if env_override is not None:
+        return env_override.lower() == "true"
+
+    # Try to use _grouped_size_compiled_for_decode_kernels if available
+    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
+    try:
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+        if not _grouped_size_compiled_for_decode_kernels(
+            num_attention_heads,
+            num_kv_heads,
+        ):
+            return True
+        else:
+            return False
+    except (ImportError, AttributeError):
+        pass
+
+    # Calculate GQA group size
+    gqa_group_size = num_attention_heads // num_kv_heads
+
+    # Determine based on dtype and GQA group size
+    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        return True
+    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
+        return gqa_group_size > 4
+    else:
+        return False
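
should_use_tensor_core decides whether the flashinfer decode path should pick the tensor-core kernels. When the SGLANG_FLASHINFER_USE_TENSOR_CORE override is unset and the old flashinfer (<= 0.1.6) probe is unavailable, the decision reduces to the dtype/GQA rule at the bottom of the function. Illustrative calls under that fallback path (the head counts are examples, not tied to a specific model):

    import torch

    from sglang.srt.utils import should_use_tensor_core

    # FP8 KV cache always prefers tensor cores on the fallback path.
    print(should_use_tensor_core(torch.float8_e4m3fn, num_attention_heads=32, num_kv_heads=8))  # True

    # FP16/BF16: tensor cores only when the GQA group (attention heads per KV head) exceeds 4.
    print(should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=8))  # 32 // 8 = 4 -> False
    print(should_use_tensor_core(torch.float16, num_attention_heads=32, num_kv_heads=4))  # 32 // 4 = 8 -> True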