sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/test/runners.py CHANGED
@@ -19,7 +19,7 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Engine
@@ -135,6 +135,76 @@ class HFRunner:
                 return True
         return False
 
+    # copy from https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct/blob/main/gme_inference.py
+
+    def _get_gme_qwen2_vl_embeddings(
+        self, prompts, image_data: Optional[List[str]] = None
+    ):
+        from sglang.srt.utils import load_image
+
+        images = None
+        if image_data is not None:
+            images = [load_image(image)[0] for image in image_data]
+
+        inputs = self.processor(
+            text=prompts,
+            images=images,
+            padding=True,
+            truncation=True,
+            max_length=1800,
+            return_tensors="pt",
+        )
+        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+        with torch.no_grad():
+            embeddings = self._forward_gme_qwen2_vl(**inputs)
+        return embeddings.tolist()
+
+    def _forward_gme_qwen2_vl(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        pooling_mask: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if inputs_embeds is None:
+            inputs_embeds = self.model.model.embed_tokens(input_ids)
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(self.model.visual.get_dtype())
+                image_embeds = self.model.visual(
+                    pixel_values, grid_thw=image_grid_thw
+                ).to(inputs_embeds.device)
+                image_mask = input_ids == self.model.config.image_token_id
+                inputs_embeds[image_mask] = image_embeds
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(inputs_embeds.device)
+
+        outputs = self.model.model(
+            input_ids=None,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+        )
+
+        pooling_mask = attention_mask if pooling_mask is None else pooling_mask
+        left_padding = pooling_mask[:, -1].sum() == pooling_mask.shape[0]  # TODO
+        if left_padding:
+            embeddings = outputs.last_hidden_state[:, -1]
+        else:
+            sequence_lengths = pooling_mask.sum(dim=1) - 1
+            batch_size = outputs.last_hidden_state.shape[0]
+            embeddings = outputs.last_hidden_state[
+                torch.arange(batch_size, device=outputs.last_hidden_state.device),
+                sequence_lengths,
+            ]
+        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+        return embeddings.contiguous()
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
         # Apply model-specific patches
         monkey_patch_gemma2_sdpa()
@@ -148,9 +218,18 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            self.model = _get_sentence_transformer_embedding_model(
-                model_path, torch_dtype
-            )
+            if "gme-qwen2-vl" in model_path.lower():
+                self.model = AutoModelForVision2Seq.from_pretrained(
+                    model_path,
+                    torch_dtype=torch_dtype,
+                    trust_remote_code=False,
+                    low_cpu_mem_usage=True,
+                ).cuda()
+                self.processor = AutoProcessor.from_pretrained(model_path)
+            else:
+                self.model = _get_sentence_transformer_embedding_model(
+                    model_path, torch_dtype
+                )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
 
@@ -169,7 +248,9 @@
 
         # Run forward
         while True:
-            prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
+            prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob = (
+                in_queue.get()
+            )
             if lora_paths is not None:
                 assert len(prompts) == len(lora_paths)
 
@@ -189,7 +270,10 @@
                 )
             elif self.model_type == "embedding":
                 assert not self.output_str_only
-                logits = self.model.encode(prompts).tolist()
+                if "gme-qwen2-vl" in model_path.lower():
+                    logits = self._get_gme_qwen2_vl_embeddings(prompts, image_data)
+                else:
+                    logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
 
             elif self.model_type == "reward":
@@ -211,11 +295,14 @@
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
        max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         token_ids_logprob: Optional[int] = None,
     ):
-        self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
+        self.in_queue.put(
+            (prompts, image_data, max_new_tokens, lora_paths, token_ids_logprob)
+        )
         return self.out_queue.get()
 
     def terminate(self):
@@ -396,6 +483,7 @@ class SRTRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
         logprob_start_len: int = 0,
@@ -413,17 +501,23 @@
                 token_ids_logprob=token_ids_logprob,
             )
         else:
-            response = self.engine.encode(prompts)
             if self.model_type == "embedding":
-                logits = [x["embedding"] for x in response]
+                response = self.engine.encode(prompt=prompts, image_data=image_data)
+                if isinstance(response, list):
+                    logits = [x["embedding"] for x in response]
+                else:
+                    logits = [response["embedding"]]
                 return ModelOutput(embed_logits=logits)
+            # reward model
             else:
+                response = self.engine.encode(prompts)
                 scores = [x["embedding"][0] for x in response]
                 return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
+        image_data: Optional[List[str]] = None,
         max_new_tokens=8,
         lora_paths=None,
     ):
@@ -439,7 +533,7 @@
             lora_paths=lora_paths,
         )
         else:
-            response = self.engine.encode(prompts)
+            response = self.engine.encode(prompts, image_data)
             if self.model_type == "embedding":
                 logits = [x["embedding"] for x in response]
                 return ModelOutput(embed_logits=logits)
sglang/test/test_block_fp8.py CHANGED
@@ -1,4 +1,5 @@
 import itertools
+import os
 import unittest
 
 import torch
@@ -7,9 +8,12 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
+    static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+
 
 # For test
 def native_per_token_group_quant_fp8(
@@ -63,7 +67,7 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
         out, scale = per_token_group_quant_fp8(x, group_size)
 
         self.assertTrue(
-            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.20)
         )
         self.assertTrue(torch.allclose(scale, ref_scale))
 
@@ -85,6 +89,71 @@ class TestPerTokenGroupQuantFP8(unittest.TestCase):
                 self._per_token_group_quant_fp8(*params)
 
 
+# For test
+def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
+    """Function to perform static quantization on an input tensor `x` using native torch.
+
+    It converts the tensor values into float8 values and returns the
+    quantized tensor along with the scaling factor used for quantization.
+    """
+    assert x.is_contiguous(), "`x` is not contiguous"
+    assert x_s.numel() == 1, "only supports per-tensor scale"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // x.shape[-1], x.shape[-1])
+    x_s_inv = 1.0 / x_s
+    x_q = (x_ * x_s_inv).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+
+    return x_q, x_s
+
+
+class TestStaticQuantFP8(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _static_quant_fp8(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(num_tokens, d, dtype=dtype)
+        fp8_max = torch.finfo(torch.float8_e4m3fn).max
+        x_s = x.max() / fp8_max
+
+        with torch.inference_mode():
+            ref_out, _ = native_static_quant_fp8(x, x_s)
+            out, _ = static_quant_fp8(x, x_s, repeat_scale=True)
+
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.50)
+        )
+
+    def test_static_quant_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._static_quant_fp8(*params)
+
+
 # For test
 def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
     """This function performs matrix multiplication with block-wise quantization using native torch.
@@ -142,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
 
 
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shape in DeepSeek V3 for test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
 
     @classmethod
     def setUpClass(cls):
@@ -156,7 +247,8 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -191,19 +283,17 @@ class TestW8A8BlockFP8Matmul(unittest.TestCase):
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
 
sglang/test/test_custom_ops.py ADDED
@@ -0,0 +1,88 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/8ca7a71df787ad711ad3ac70a5bd2eb2bb398938/tests/quantization/test_fp8.py
+
+import pytest
+import torch
+
+from sglang.srt.custom_op import scaled_fp8_quant
+from sglang.srt.utils import is_cuda
+
+
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_scaled_fp8_quant_per_tensor(dtype) -> None:
+
+    def quantize_ref_per_tensor(tensor, inv_scale):
+        # The reference implementation that fully aligns to
+        # the kernel being tested.
+        finfo = torch.finfo(torch.float8_e4m3fn)
+        scale = inv_scale.reciprocal()
+        qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
+        qweight = qweight.to(torch.float8_e4m3fn)
+        return qweight
+
+    def dequantize_per_tensor(tensor, inv_scale, dtype):
+        fake_qweight = tensor.to(dtype)
+        dq_weight = fake_qweight * inv_scale
+        return dq_weight
+
+    # Note that we use a shape % 8 != 0 to cover edge cases,
+    # because scaled_fp8_quant is vectorized by 8.
+    x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
+
+    # Test Per Tensor Dynamic quantization
+    # scale = max(abs(x)) / FP8_E4M3_MAX
+    y, scale = scaled_fp8_quant(x, None)
+    ref_y = quantize_ref_per_tensor(x, scale)
+    torch.testing.assert_close(y, ref_y)
+    torch.testing.assert_close(
+        dequantize_per_tensor(y, scale, dtype),
+        dequantize_per_tensor(ref_y, scale, dtype),
+    )
+
+    # Test Per Tensor Static quantization
+    y, _ = scaled_fp8_quant(x, scale)
+    ref_y = quantize_ref_per_tensor(x, scale)
+    torch.testing.assert_close(y, ref_y)
+    torch.testing.assert_close(
+        dequantize_per_tensor(y, scale, dtype),
+        dequantize_per_tensor(ref_y, scale, dtype),
+    )
+
+
+if is_cuda:
+
+    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
+        def quantize_ref_per_token(tensor, inv_scale):
+            # The reference implementation that fully aligns to
+            # the kernel being tested.
+            finfo = torch.finfo(torch.float8_e4m3fn)
+            scale = inv_scale.reciprocal()
+            qweight = (tensor.to(torch.float32) * scale).clamp(
+                min=finfo.min, max=finfo.max
+            )
+            qweight = qweight.to(torch.float8_e4m3fn)
+            return qweight
+
+        def dequantize_per_token(tensor, inv_scale, dtype):
+            fake_qweight = tensor.to(dtype)
+            dq_weight = fake_qweight * inv_scale
+            return dq_weight
+
+        # Note that we use a shape % 8 = 0,
+        # because per_token_quant_fp8 is vectorized by 8 elements.
+        x = (torch.randn(size=(11, 16), device="cuda") * 13).to(dtype)
+
+        # Test Per Tensor Dynamic quantization
+        # scale = max(abs(x)) / FP8_E4M3_MAX
+        y, scale = scaled_fp8_quant(x, None, use_per_token_if_dynamic=True)
+        ref_y = quantize_ref_per_token(x, scale)
+        torch.testing.assert_close(y, ref_y)
+        torch.testing.assert_close(
+            dequantize_per_token(y, scale, dtype),
+            dequantize_per_token(ref_y, scale, dtype),
+        )
+
+
+if __name__ == "__main__":
+    # Run the specific test function directly
+    pytest.main([__file__])
sglang/test/test_utils.py CHANGED
@@ -28,6 +28,10 @@ from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback
 
 DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
+DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST = "neuralmagic/Meta-Llama-3-8B-Instruct-FP8"
+DEFAULT_FP8_MODEL_NAME_FOR_DYNAMIC_QUANT_ACCURACY_TEST = (
+    "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8-dynamic"
+)
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
@@ -36,12 +40,15 @@ DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
 DEFAULT_REASONING_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
+DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST = (
+    "hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
+)
 DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
-DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
+DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4,hugging-quants/Mixtral-8x7B-Instruct-v0.1-AWQ-INT4"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
 DEFAULT_SMALL_VLM_MODEL_NAME = "Qwen/Qwen2-VL-2B"
 
@@ -446,22 +453,31 @@ def run_with_timeout(
     return ret_value[0]
 
 
-def run_unittest_files(files: List[str], timeout_per_file: float):
+def run_unittest_files(files: List, timeout_per_file: float):
     tic = time.time()
     success = True
 
-    for filename in files:
+    for file in files:
+        filename, estimated_time = file.name, file.estimated_time
         process = None
 
         def run_one_file(filename):
            nonlocal process
 
            filename = os.path.join(os.getcwd(), filename)
-            print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
+            print(f".\n.\nBegin:\npython3 {filename}\n.\n.\n", flush=True)
+            tic = time.time()
+
             process = subprocess.Popen(
                 ["python3", filename], stdout=None, stderr=None, env=os.environ
             )
             process.wait()
+            elapsed = time.time() - tic
+
+            print(
+                f".\n.\nEnd:\n{filename=}, {elapsed=:.0f}, {estimated_time=}\n.\n.\n",
+                flush=True,
+            )
             return process.returncode
 
         try:
sglang/utils.py CHANGED
@@ -24,14 +24,10 @@ import requests
 from IPython.display import HTML, display
 from tqdm import tqdm
 
-from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
 from sglang.srt.utils import kill_process_tree
 
 logger = logging.getLogger(__name__)
 
-# type of content fields, can be only prompts or with images/videos
-MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
-
 
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.4.3.post4"
+__version__ = "0.4.4.post1"
{sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: sglang
-Version: 0.4.3.post4
+Version: 0.4.4.post1
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
 Version 2.0, January 2004
@@ -211,19 +211,22 @@ Classifier: License :: OSI Approved :: Apache Software License
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
 License-File: LICENSE
+Requires-Dist: aiohttp
 Requires-Dist: requests
 Requires-Dist: tqdm
 Requires-Dist: numpy
 Requires-Dist: IPython
 Requires-Dist: setproctitle
 Provides-Extra: runtime-common
-Requires-Dist: aiohttp; extra == "runtime-common"
+Requires-Dist: datasets; extra == "runtime-common"
 Requires-Dist: decord; extra == "runtime-common"
 Requires-Dist: fastapi; extra == "runtime-common"
 Requires-Dist: hf_transfer; extra == "runtime-common"
 Requires-Dist: huggingface_hub; extra == "runtime-common"
 Requires-Dist: interegular; extra == "runtime-common"
+Requires-Dist: llguidance>=0.6.15; extra == "runtime-common"
 Requires-Dist: modelscope; extra == "runtime-common"
+Requires-Dist: ninja; extra == "runtime-common"
 Requires-Dist: orjson; extra == "runtime-common"
 Requires-Dist: packaging; extra == "runtime-common"
 Requires-Dist: pillow; extra == "runtime-common"
@@ -233,24 +236,20 @@ Requires-Dist: pydantic; extra == "runtime-common"
 Requires-Dist: python-multipart; extra == "runtime-common"
 Requires-Dist: pyzmq>=25.1.2; extra == "runtime-common"
 Requires-Dist: torchao>=0.7.0; extra == "runtime-common"
+Requires-Dist: transformers==4.48.3; extra == "runtime-common"
 Requires-Dist: uvicorn; extra == "runtime-common"
 Requires-Dist: uvloop; extra == "runtime-common"
-Requires-Dist: xgrammar==0.1.14; extra == "runtime-common"
-Requires-Dist: ninja; extra == "runtime-common"
-Requires-Dist: transformers==4.48.3; extra == "runtime-common"
-Requires-Dist: llguidance>=0.6.15; extra == "runtime-common"
-Requires-Dist: datasets; extra == "runtime-common"
+Requires-Dist: xgrammar==0.1.15; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
-Requires-Dist: sgl-kernel==0.0.3.post6; extra == "srt"
-Requires-Dist: flashinfer_python==0.2.2.post1; extra == "srt"
+Requires-Dist: sgl-kernel==0.0.5; extra == "srt"
+Requires-Dist: flashinfer_python==0.2.3; extra == "srt"
 Requires-Dist: torch==2.5.1; extra == "srt"
 Requires-Dist: vllm<=0.7.2,>=0.6.4.post1; extra == "srt"
 Requires-Dist: cuda-python; extra == "srt"
 Requires-Dist: outlines<=0.1.11,>=0.0.44; extra == "srt"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
-Requires-Dist: sgl-kernel==0.0.3.post6; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
 Requires-Dist: vllm==0.6.7.dev2; extra == "srt-hip"
 Requires-Dist: outlines==0.1.11; extra == "srt-hip"