sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -141,6 +141,7 @@ class ServerArgs:
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
+    triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
 
@@ -220,12 +221,10 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
             self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
                 "Overlap scheduler is disabled."
@@ -282,7 +281,15 @@ class ServerArgs:
             "--load-format",
             type=str,
             default=ServerArgs.load_format,
-            choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"],
+            choices=[
+                "auto",
+                "pt",
+                "safetensors",
+                "npcache",
+                "dummy",
+                "gguf",
+                "bitsandbytes",
+            ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             "and fall back to the pytorch bin format if safetensors format "
@@ -293,7 +300,9 @@ class ServerArgs:
             "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling."
-            '"gguf" will load the weights in the gguf format. ',
+            '"gguf" will load the weights in the gguf format. '
+            '"bitsandbytes" will load the weights using bitsandbytes '
+            "quantization.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -689,11 +698,6 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
-        parser.add_argument(
-            "--disable-nan-detection",
-            action="store_true",
-            help="Disable the NaN detection for better performance.",
-        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
@@ -753,6 +757,12 @@ class ServerArgs:
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
         )
+        parser.add_argument(
+            "--triton-attention-num-kv-splits",
+            type=int,
+            default=ServerArgs.triton_attention_num_kv_splits,
+            help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
+        )
         parser.add_argument(
             "--num-continuous-decode-steps",
             type=int,
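
Taken together, this file adds a `bitsandbytes` load format and a tunable KV-split count for the Triton decode kernel. A minimal sketch of driving both new options through the argparse plumbing shown above; the model path is a placeholder, and it assumes `ServerArgs.add_cli_args` / `ServerArgs.from_cli_args` behave as in previous releases:

```python
# Sketch only: exercise the new --load-format choice and the new
# --triton-attention-num-kv-splits flag added in this release.
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)

args = parser.parse_args(
    [
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--load-format", "bitsandbytes",                      # new choice in 0.4.1
        "--triton-attention-num-kv-splits", "16",             # default is 8
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.load_format, server_args.triton_attention_num_kv_splits)
```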
sglang/srt/utils.py CHANGED
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import dataclasses
 import ipaddress
 import itertools
 import json
@@ -92,7 +93,7 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and not is_hip()
+    return torch.cuda.is_available() and torch.version.cuda
 
 
 def is_ipv6(address):
@@ -169,7 +170,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(device, gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
@@ -184,7 +185,8 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
                 "which may cause useless memory allocation for torch CUDA context.",
             )
 
-        torch.cuda.empty_cache()
+        if empty_cache:
+            torch.cuda.empty_cache()
         free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
 
     elif device == "xpu":
@@ -196,7 +198,9 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
                 f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
                 "which may cause useless memory allocation for torch XPU context.",
             )
-        torch.xpu.empty_cache()
+
+        if empty_cache:
+            torch.xpu.empty_cache()
         used_memory = torch.xpu.memory_allocated()
         total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
         free_gpu_memory = total_gpu_memory - used_memory
@@ -1068,9 +1072,6 @@ def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
 
-    if hasattr(torch, "hip") and torch.hip.is_available():
-        return torch.hip.get_device_name(device_id)
-
     if hasattr(torch, "xpu") and torch.xpu.is_available():
         return torch.xpu.get_device_name(device_id)
 
@@ -1083,9 +1084,6 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         major, minor = torch.cuda.get_device_capability(device_id)
 
-    if hasattr(torch, "hip") and torch.hip.is_available():
-        major, minor = torch.cuda.get_device_capability(device_id)
-
     if hasattr(torch, "xpu") and torch.xpu.is_available():
         major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
             "."
@@ -1241,49 +1239,37 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
 
 
-def should_use_tensor_core(
-    kv_cache_dtype: torch.dtype,
-    num_attention_heads: int,
-    num_kv_heads: int,
-) -> bool:
-    """
-    Determine whether to use tensor cores for attention computation.
-
-    Args:
-        kv_cache_dtype: Data type of the KV cache
-        num_attention_heads: Number of attention heads
-        num_kv_heads: Number of key/value heads
-
-    Returns:
-        bool: Whether to use tensor cores
-    """
-    # Try to use environment variable first
-    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
-    if env_override is not None:
-        return env_override.lower() == "true"
-
-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
+def dataclass_to_string_truncated(data, max_length=2048):
+    if isinstance(data, str):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return f'"{data[:half_length]} ... {data[-half_length:]}"'
         else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
-    # Calculate GQA group size
-    gqa_group_size = num_attention_heads // num_kv_heads
-
-    # Determine based on dtype and GQA group size
-    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        return True
-    elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-        return gqa_group_size > 4
+            return f'"{data}"'
+    elif isinstance(data, (list, tuple)):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return str(data[:half_length]) + " ... " + str(data[-half_length:])
+        else:
+            return str(data)
+    elif isinstance(data, dict):
+        return (
+            "{"
+            + ", ".join(
+                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                for k, v in data.items()
+            )
+            + "}"
+        )
+    elif dataclasses.is_dataclass(data):
+        fields = dataclasses.fields(data)
+        return (
+            f"{data.__class__.__name__}("
+            + ", ".join(
+                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
+                for f in fields
+            )
+            + ")"
+        )
     else:
-        return False
+        return str(data)
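
The `dataclass_to_string_truncated` helper that replaces `should_use_tensor_core` here is a logging aid: it renders strings, containers, and dataclasses while clipping anything longer than `max_length` to its head and tail. A small self-contained sketch of the behavior; the `DummyRequest` dataclass is made up for illustration:

```python
import dataclasses

from sglang.srt.utils import dataclass_to_string_truncated


@dataclasses.dataclass
class DummyRequest:  # hypothetical example class, not part of sglang
    rid: str
    text: str


req = DummyRequest(rid="abc", text="x" * 10_000)

# Long fields are clipped to max_length characters, keeping the head and tail:
# roughly DummyRequest(rid="abc", text="xxx...32 chars... ... ...32 chars...xxx")
print(dataclass_to_string_truncated(req, max_length=64))
```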
sglang/test/test_block_fp8.py ADDED
@@ -0,0 +1,341 @@
+import itertools
+import unittest
+
+import torch
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    w8a8_block_fp8_matmul,
+)
+
+
+# For test
+def native_per_token_group_quant_fp8(
+    x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
+):
+    """Function to perform per-token-group quantization on an input tensor `x` using native torch.
+
+    It converts the tensor values into float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    Note that only `torch.float8_e4m3fn` is supported for now.
+    """
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // group_size, group_size)
+    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
+    x_s = amax / fp8_max
+    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
+
+    return x_q, x_s
+
+
+class TestPerTokenGroupQuantFP8(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    GROUP_SIZE = [64, 128, 256, 512]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(num_tokens, d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
+            out, scale = per_token_group_quant_fp8(x, group_size)
+
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+        )
+        self.assertTrue(torch.allclose(scale, ref_scale))
+
+    def test_per_token_group_quant_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.GROUP_SIZE,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                group_size=params[3],
+                seed=params[4],
+            ):
+                self._per_token_group_quant_fp8(*params)
+
+
+# For test
+def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
+    """This function performs matrix multiplication with block-wise quantization using native torch.
+
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
+    """
+
+    A = A.to(torch.float32)
+    B = B.to(torch.float32)
+    assert A.shape[-1] == B.shape[-1]
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+
+    M = A.numel() // A.shape[-1]
+    N, K = B.shape
+    origin_C_shape = A.shape[:-1] + (N,)
+    A = A.reshape(M, A.shape[-1])
+    As = As.reshape(M, As.shape[-1])
+    n_tiles = (N + block_n - 1) // block_n
+    k_tiles = (K + block_k - 1) // block_k
+    assert n_tiles == Bs.shape[0]
+    assert k_tiles == Bs.shape[1]
+
+    C_shape = (M, N)
+    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
+
+    A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
+    B_tiles = [
+        [
+            B[
+                j * block_n : min((j + 1) * block_n, N),
+                i * block_k : min((i + 1) * block_k, K),
+            ]
+            for i in range(k_tiles)
+        ]
+        for j in range(n_tiles)
+    ]
+    C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
+    As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a = A_tiles[i]
+            b = B_tiles[j][i]
+            c = C_tiles[j]
+            s = As_tiles[i] * Bs[j][i]
+            c[:, :] += torch.matmul(a, b.t()) * s
+
+    C = C.reshape(origin_C_shape).to(output_dtype)
+    return C
+
+
+class TestW8A8BlockFP8Matmul(unittest.TestCase):
+    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+    M = [1, 7, 83, 512, 2048]
+    N = [128, 512, 1024, 4096, 7748, 13824]
+    K = [256, 4096, 5120, 3884, 13824]
+    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+    BLOCK_SIZE = [[128, 128]]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+        torch.manual_seed(seed)
+        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
+        factor_for_scale = 1e-2
+        fp8_info = torch.finfo(torch.float8_e4m3fn)
+        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+        A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+        A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
+        B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        block_n, block_k = block_size[0], block_size[1]
+        n_tiles = (N + block_n - 1) // block_n
+        k_tiles = (K + block_k - 1) // block_k
+
+        As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
+        Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
+
+        with torch.inference_mode():
+            ref_out = native_w8a8_block_fp8_matmul(
+                A_fp8, B_fp8, As, Bs, block_size, out_dtype
+            )
+            out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
+
+        self.assertTrue(
+            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
+            / torch.mean(torch.abs(ref_out.to(torch.float32)))
+            < 0.001
+        )
+
+    def test_w8a8_block_fp8_matmul(self):
+        for params in itertools.product(
+            self.M,
+            self.N,
+            self.K,
+            self.BLOCK_SIZE,
+            self.OUT_DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                M=params[0],
+                N=params[1],
+                K=params[2],
+                block_size=params[3],
+                out_dtype=params[4],
+                seed=params[5],
+            ):
+                self._w8a8_block_fp8_matmul(*params)
+
+
+# For test
+def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
+    """This function performs fused moe with block-wise quantization using native torch."""
+
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+
+    _, block_k = block_shape[0], block_shape[1]
+    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
+    # NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``.
+    a_q = a_q.to(torch.float32)
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            inter_out = native_w8a8_block_fp8_matmul(
+                a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype
+            )
+            act_out = SiluAndMul().forward_native(inter_out)
+            act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k)
+            act_out = act_out.to(torch.float32)
+            out[mask] = native_w8a8_block_fp8_matmul(
+                act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype
+            )
+    return (
+        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+    ).sum(dim=1)
+
+
+class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
+    DTYPES = [torch.float32, torch.half, torch.bfloat16]
+    M = [1, 33, 64, 222, 1024 * 128]
+    N = [128, 1024, 2048]
+    K = [256, 4096, 5120]
+    E = [8, 24]
+    TOP_KS = [2, 6]
+    BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+    # BLOCK_SIZE = [[128, 128]]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _w8a8_block_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed):
+        torch.manual_seed(seed)
+        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
+        factor_for_scale = 1e-2
+        fp8_info = torch.finfo(torch.float8_e4m3fn)
+        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+        a = torch.randn((M, K), dtype=dtype) / 10
+
+        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max
+        w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max
+        w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        block_n, block_k = block_size[0], block_size[1]
+        n_tiles_w1 = (2 * N + block_n - 1) // block_n
+        n_tiles_w2 = (K + block_n - 1) // block_n
+        k_tiles_w1 = (K + block_k - 1) // block_k
+        k_tiles_w2 = (N + block_k - 1) // block_k
+
+        w1_s = (
+            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
+            * factor_for_scale
+        )
+        w2_s = (
+            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
+            * factor_for_scale
+        )
+
+        score = torch.randn((M, E), dtype=dtype)
+
+        with torch.inference_mode():
+            out = fused_moe(
+                a,
+                w1,
+                w2,
+                score,
+                topk,
+                renormalize=False,
+                use_fp8_w8a8=True,
+                w1_scale=w1_s,
+                w2_scale=w2_s,
+                block_shape=block_size,
+            )
+            ref_out = torch_w8a8_block_fp8_moe(
+                a, w1, w2, w1_s, w2_s, score, topk, block_size
+            )
+
+        self.assertTrue(
+            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
+            / torch.mean(torch.abs(ref_out.to(torch.float32)))
+            < 0.02
+        )
+
+    def test_w8a8_block_fp8_fused_moe(self):
+        for params in itertools.product(
+            self.M,
+            self.N,
+            self.K,
+            self.E,
+            self.TOP_KS,
+            self.BLOCK_SIZE,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                M=params[0],
+                N=params[1],
+                K=params[2],
+                E=params[3],
+                topk=params[4],
+                block_size=params[5],
+                dtype=params[6],
+                seed=params[7],
+            ):
+                self._w8a8_block_fp8_fused_moe(*params)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
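
The reference implementations above quantize each group of `group_size` channels with a single scale chosen so that the group's absolute maximum maps to the fp8 e4m3 maximum of 448.0 (`x_s = amax / fp8_max`, `x_q = clamp(x / x_s)`). A tiny hand-worked sketch of that scaling rule in plain Python, with no GPU or float8 tensors required:

```python
# Per-token-group FP8 scaling, worked on a single group of four values.
FP8_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max

group = [0.25, -1.5, 3.0, 0.75]          # one token, one channel group
amax = max(abs(v) for v in group)        # 3.0
scale = amax / FP8_MAX                   # ~0.0067, stored as x_s
quantized = [v / scale for v in group]   # [~37.3, -224.0, 448.0, 112.0]

# Dequantizing recovers the inputs up to fp8 rounding: q * scale == original here.
dequantized = [q * scale for q in quantized]
print(dequantized)                       # [0.25, -1.5, 3.0, 0.75]
```

On a CUDA machine the whole suite can be run with `python -m unittest sglang.test.test_block_fp8 -v`; the fused MoE cases compare the Triton kernels against these torch references with a 2% relative-error threshold, and the plain matmul cases with 0.1%.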
sglang/test/test_utils.py CHANGED
@@ -568,6 +568,7 @@ def run_bench_serving(
         disable_tqdm=False,
         disable_stream=disable_stream,
         disable_ignore_eos=False,
+        return_logprob=False,
         lora_name=None,
         extra_request_body=None,
         profile=None,
@@ -719,13 +720,13 @@ def run_and_check_memory_leak(
 
     # Clean up everything
     kill_process_tree(process.pid)
-    kill_process_tree(process.pid)
     stdout.close()
     stderr.close()
     if os.path.exists(STDOUT_FILENAME):
        os.remove(STDOUT_FILENAME)
     if os.path.exists(STDERR_FILENAME):
        os.remove(STDERR_FILENAME)
+    kill_process_tree(process.pid)
     t.join()
 
     # Assert success
@@ -733,7 +734,7 @@ def run_and_check_memory_leak(
     has_leak = False
     has_abort = False
     for line in output_lines:
-        if "The server is fired" in line:
+        if "Uvicorn running" in line:
            has_new_server = True
        if "leak" in line:
            has_leak = True
sglang/utils.py CHANGED
@@ -1,4 +1,4 @@
-"""Common utilities."""
+"""Common utilities"""
 
 import base64
 import gc
@@ -79,7 +79,14 @@ class HttpResponse:
         return self.resp.status
 
 
-def http_request(url, json=None, stream=False, api_key=None, verify=None):
+def http_request(
+    url,
+    json=None,
+    stream=False,
+    api_key=None,
+    verify=None,
+    method: Optional[str] = None,
+):
     """A faster version of requests.post with low-level urllib API."""
     headers = {"Content-Type": "application/json; charset=utf-8"}
 
@@ -90,7 +97,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None):
     if stream:
         return requests.post(url, json=json, stream=True, headers=headers)
     else:
-        req = urllib.request.Request(url, headers=headers)
+        req = urllib.request.Request(url, headers=headers, method=method)
         if json is None:
             data = None
         else:
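
The new `method` parameter is forwarded straight to `urllib.request.Request`, so callers can now issue non-POST requests through the same helper. A minimal sketch, assuming a locally running server; the URL and endpoint are placeholders:

```python
from sglang.utils import http_request

# Issue a GET instead of the default POST by passing `method`; the value is
# handed directly to urllib.request.Request on the non-streaming path.
resp = http_request("http://localhost:30000/get_model_info", method="GET")
```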
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.0.post1"
+__version__ = "0.4.1"
{sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.4.0.post1
+Version: 0.4.1
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
          Version 2.0, January 2004
@@ -215,6 +215,7 @@ Requires-Dist: requests
 Requires-Dist: tqdm
 Requires-Dist: numpy
 Requires-Dist: IPython
+Requires-Dist: setproctitle
 Provides-Extra: runtime-common
 Requires-Dist: aiohttp; extra == "runtime-common"
 Requires-Dist: decord; extra == "runtime-common"
@@ -232,16 +233,17 @@ Requires-Dist: psutil; extra == "runtime-common"
 Requires-Dist: pydantic; extra == "runtime-common"
 Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
-Requires-Dist: torchao; extra == "runtime-common"
+Requires-Dist: torchao>=0.7.0; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
-Requires-Dist: xgrammar>=0.1.4; extra == "runtime-common"
+Requires-Dist: xgrammar>=0.1.6; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: torch; extra == "srt"
 Requires-Dist: vllm<=0.6.4.post1,>=0.6.3.post1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
-Requires-Dist: flashinfer>=0.1.6; extra == "srt"
+Requires-Dist: flashinfer==0.1.6; extra == "srt"
+Requires-Dist: sgl-kernel>=0.0.2.post8; extra == "srt"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
@@ -311,8 +313,11 @@ Requires-Dist: sglang[test]; extra == "dev-hpu"
 
 --------------------------------------------------------------------------------
 
-| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Documentation**](https://sgl-project.github.io/) | [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2tmmp6flg-89dOlJW2TjnBrTRk1I_~GA) |
-[**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
+| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/)
+| [**Documentation**](https://sgl-project.github.io/)
+| [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2tmmp6flg-89dOlJW2TjnBrTRk1I_~GA)
+| [**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing)
+| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |
 
 ## News
 - [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
@@ -353,7 +358,7 @@ Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
 [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487)
 
 ## Adoption and Sponsorship
-The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI.
+The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI, 01.AI and DataCrunch.
 
 ## Acknowledgment and Citation
 We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).