sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff compares two publicly available versions of the package as released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
@@ -1,8 +1,10 @@
+import os
 from typing import List, Optional, Tuple
 
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import (
+    _enable_jit_deepgemm,
     per_token_group_quant_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
@@ -15,6 +17,14 @@ from sglang.srt.utils import (
     is_hip,
 )
 
+try:
+    import vllm
+    from vllm import _custom_ops as ops
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
 use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
 
 _is_hip = is_hip()
@@ -23,19 +33,29 @@ if _is_hip and get_bool_env_var("CK_MOE"):
 
 _is_cuda = is_cuda()
 if _is_cuda:
-    from sgl_kernel import fp8_blockwise_scaled_mm
+    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
 
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
     from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8
 
-    if use_vllm_cutlass_w8a8_fp8_kernel:
-        from vllm import _custom_ops as ops
-    else:
-        from sgl_kernel import fp8_scaled_mm
-
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
+_TORCH_VERSION = torch.__version__.split("+")[0]
+try:
+    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+except ValueError:
+    _TORCH_VERSION_TUPLE = (0, 0, 0)
+
+# The condition to determine if it is on a platform that supports
+# torch._scaled_mm rowwise feature.
+# The condition is determined once as the operations
+# are time consuming.
+USE_ROWWISE_TORCH_SCALED_MM = (
+    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+)
+
 
 def cutlass_fp8_supported():
     if not _is_cuda:
@@ -74,7 +94,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
 
 
 def cutlass_block_fp8_supported() -> bool:
-    if get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
+    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
@@ -122,9 +142,17 @@ def apply_w8a8_block_fp8_linear(
         )
         gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
     else:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
+        if _enable_jit_deepgemm:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d,
+                block_size[1],
+                column_major_scales=True,
+                scale_tma_aligned=True,
+            )
+        else:
+            q_input, x_scale = per_token_group_quant_fp8(
+                input_2d, block_size[1], column_major_scales=False
+            )
         output = w8a8_block_fp8_matmul(
             q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
         )
@@ -219,24 +247,32 @@ def apply_fp8_linear(
         )
 
     if cutlass_fp8_supported:
-        if use_vllm_cutlass_w8a8_fp8_kernel:
-            # Fall back to vllm cutlass w8a8 fp8 kernel
-            output = ops.cutlass_scaled_mm(
-                qinput,
-                weight,
-                out_dtype=input.dtype,
-                scale_a=x_scale,
-                scale_b=weight_scale,
-                bias=bias,
-            )
-        else:
-            assert (
-                weight_scale.numel() == weight.shape[1]
-            ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
-            output = fp8_scaled_mm(
-                qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias
-            )
-        return output.view(*output_shape)
+        try:
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
+        except (ImportError, NameError, AttributeError):
+            pass
 
     # torch.scaled_mm supports per tensor weights + activations only
     # so fallback to naive if per channel or per token
@@ -306,3 +342,223 @@ def apply_fp8_linear(
         if bias is not None:
             output = output + bias
         return output.to(dtype=input.dtype).view(*output_shape)
+
+
+def maybe_create_device_identity():
+    # Allocate dummy ones tensor for torch._scaled_mm
+    global TORCH_DEVICE_IDENTITY
+    if TORCH_DEVICE_IDENTITY is None:
+        TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+# TODO(luka): follow similar pattern for marlin and block-fp8-linear
+# https://github.com/vllm-project/vllm/issues/14397
+class Fp8LinearOp:
+    """
+    This class executes a FP8 linear layer using cutlass if supported and
+    torch.scaled_mm otherwise.
+    It needs to be a class instead of a method so that config can be read
+    in the __init__ method, as reading config is not allowed inside forward.
+    """
+
+    def __init__(
+        self,
+        cutlass_fp8_supported: bool = cutlass_fp8_supported(),
+        use_per_token_if_dynamic: bool = False,
+        pad_output: Optional[bool] = None,
+    ):
+        self.cutlass_fp8_supported = cutlass_fp8_supported
+        self.use_per_token_if_dynamic = use_per_token_if_dynamic
+
+        # Note: we pad the input because torch._scaled_mm is more performant
+        # for matrices with batch dimension > 16.
+        # This could change in the future.
+        # We also don't pad when using torch.compile,
+        # as it breaks with dynamic shapes.
+        if pad_output is None:
+            enable_torch_compile = os.environ.get(
+                "SGLANG_ENABLE_TORCH_COMPILE", "0"
+            ).lower() in ("1", "true", "yes")
+            pad_output = not enable_torch_compile
+        self.output_padding = 17 if pad_output else None
+
+    def apply(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None,
+        input_scale_ub: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        # TODO(luka) remove this parameter in favor of __init__
+        use_per_token_if_dynamic: Optional[bool] = None,
+    ) -> torch.Tensor:
+        # ops.scaled_fp8_quant supports both dynamic and static quant.
+        # If dynamic, layer.input_scale is None and x_scale computed from x.
+        # If static, layer.input_scale is scalar and x_scale is input_scale.
+
+        # View input as 2D matrix for fp8 methods
+        input_2d = input.view(-1, input.shape[-1])
+        output_shape = [*input.shape[:-1], weight.shape[1]]
+
+        # TODO(luka) this is here because currently MLA only decides this
+        # during the forward method instead of in __init__.
+        if use_per_token_if_dynamic is None:
+            use_per_token_if_dynamic = self.use_per_token_if_dynamic
+
+        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+        # for sgl-kernel fp8_scaled_mm, it support per channel W now
+        if self.cutlass_fp8_supported and weight_scale.numel() == weight.shape[1]:
+            if _is_cuda:
+                qinput, x_scale = sgl_scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+            else:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    scale_ub=input_scale_ub,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+
+            # Fused GEMM_DQ
+            if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel:
+                # Fall back to vllm cutlass w8a8 fp8 kernel
+                output = ops.cutlass_scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+            else:
+                assert (
+                    weight_scale.numel() == weight.shape[1]
+                ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale"
+                output = fp8_scaled_mm(
+                    qinput,
+                    weight,
+                    x_scale,
+                    weight_scale,
+                    out_dtype=input.dtype,
+                    bias=bias,
+                )
+            return output.view(*output_shape)
+
+        # torch.scaled_mm supports per tensor weights + activations only
+        # so fallback to naive if per channel or per token
+        else:
+            # Maybe apply padding to output, see comment in __init__
+            if _is_cuda:
+                qinput, x_scale = sgl_scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+                if self.output_padding:
+                    pad_size = max(self.output_padding - qinput.shape[0], 0)
+                    if pad_size > 0:
+                        qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
+            else:
+                qinput, x_scale = ops.scaled_fp8_quant(
+                    input_2d,
+                    input_scale,
+                    num_token_padding=self.output_padding,
+                    use_per_token_if_dynamic=use_per_token_if_dynamic,
+                )
+
+            per_tensor_weights = weight_scale.numel() == 1
+            per_tensor_activations = x_scale.numel() == 1
+
+            if per_tensor_weights and per_tensor_activations:
+                # Fused GEMM_DQ
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale,
+                    bias=bias,
+                )
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+
+                return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
+
+            elif (
+                use_per_token_if_dynamic
+                and not per_tensor_weights
+                and not per_tensor_activations
+                and USE_ROWWISE_TORCH_SCALED_MM
+            ):
+                # For now validated on ROCm platform
+                # fp8 rowwise scaling in torch._scaled_mm is introduced in
+                # https://github.com/pytorch/pytorch/pull/144432 using hipBLASLt
+                # and ROCm 6.3, which only exists in torch 2.7 and above.
+                # For CUDA platform please validate if the
+                # torch._scaled_mm support rowwise scaled GEMM
+                # Fused GEMM_DQ Rowwise GEMM
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    out_dtype=input.dtype,
+                    scale_a=x_scale,
+                    scale_b=weight_scale.t(),
+                    bias=bias,
+                )
+
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                output = output.view(*output_shape)
+                return output
+
+            else:
+                # Fallback for channelwise case, where we use unfused DQ
+                # due to limitations with scaled_mm
+
+                # Symmetric quantized GEMM by definition computes the following:
+                # C = (s_x * X) (s_w * W) + bias
+                # This is equivalent to dequantizing the weights and activations
+                # before applying a GEMM.
+                #
+                # In order to compute quantized operands, a quantized kernel
+                # will rewrite the above like so:
+                # C = s_w * s_x * (X * W) + bias
+                #
+                # For the scaled_mm fallback case, we break this down, since it
+                # does not support s_w being a vector.
+
+                # GEMM
+                # This computes C = (X * W).
+                # Output in fp32 to allow subsequent ops to happen in-place
+
+                global TORCH_DEVICE_IDENTITY
+                if TORCH_DEVICE_IDENTITY.device != weight.device:
+                    TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
+
+                output = torch._scaled_mm(
+                    qinput,
+                    weight,
+                    scale_a=TORCH_DEVICE_IDENTITY,
+                    scale_b=TORCH_DEVICE_IDENTITY,
+                    out_dtype=torch.float32,
+                )
+                # A fix for discrepancy in scaled_mm which returns tuple
+                # for torch < 2.5 and a single value in torch >= 2.5
+                if type(output) is tuple and len(output) == 2:
+                    output = output[0]
+                # Unpad (undo num_token_padding)
+                output = torch.narrow(output, 0, 0, input_2d.shape[0])
+                x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0])
+
+                # DQ
+                # C = sw * sx * (X * W) + bias
+                output = output * x_scale * weight_scale.t()
+                if bias is not None:
+                    output = output + bias
+                return output.to(dtype=input.dtype).view(*output_shape)
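Note (illustration, not part of the diff): the Fp8LinearOp class added above can be driven roughly as follows. The shapes, the dummy fp8 weight, and the per-channel scale here are assumptions for demonstration; real callers pass already-quantized checkpoint weights, and the path taken (cutlass vs. torch._scaled_mm) depends on the GPU and the installed sgl-kernel/vllm kernels, so treat this as a sketch rather than a verified snippet.

# Illustrative sketch only; assumes an FP8-capable CUDA GPU with sgl-kernel installed.
import torch

from sglang.srt.layers.quantization.fp8_utils import Fp8LinearOp

if torch.cuda.is_available():
    op = Fp8LinearOp(use_per_token_if_dynamic=True)

    batch, in_features, out_features = 4, 128, 256
    x = torch.randn(batch, in_features, dtype=torch.bfloat16, device="cuda")

    # Stand-in for an fp8 checkpoint weight: (in_features, out_features) fp8 data
    # plus one scale per output channel, which is the layout apply() expects.
    weight = torch.randn(in_features, out_features, device="cuda").to(torch.float8_e4m3fn)
    weight_scale = torch.ones(out_features, dtype=torch.float32, device="cuda")

    # input_scale=None takes the dynamic activation-quantization path inside apply().
    y = op.apply(x, weight, weight_scale, input_scale=None, bias=None)
    assert y.shape == (batch, out_features)

Passing a static per-tensor input_scale tensor instead of None would exercise the static quantization branch described in the comments of apply().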
sglang/srt/layers/quantization/gptq.py
@@ -3,11 +3,19 @@ from fractions import Fraction
 from typing import Any, Dict, List, Optional, Union
 
 import torch
-from vllm.scalar_type import scalar_types
 
 from sglang.srt.layers.linear import LinearBase
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -110,6 +118,9 @@ class GPTQConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["GPTQLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 
         from sglang.srt.layers.quantization import get_linear_quant_method
@@ -120,11 +131,16 @@ class GPTQConfig(QuantizationConfig):
 class GPTQMarlinConfig(QuantizationConfig):
     """Config class for GPTQ Marlin"""
 
-    # (num_bits, is_sym) -> quant_type
-    TYPE_MAP = {
-        (4, True): scalar_types.uint4b8,
-        (8, True): scalar_types.uint8b128,
-    }
+    if VLLM_AVAILABLE:
+        from vllm.scalar_type import scalar_types
+
+        # (num_bits, is_sym) -> quant_type
+        TYPE_MAP = {
+            (4, True): scalar_types.uint4b8,
+            (8, True): scalar_types.uint8b128,
+        }
+    else:
+        raise ImportError("vllm is not installed")
 
     def __init__(
         self,
@@ -263,6 +279,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.gptq_marlin import (
             GPTQMarlinLinearMethod,
             GPTQMarlinMoEMethod,
@@ -285,6 +304,9 @@ class GPTQMarlinConfig(QuantizationConfig):
 
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
+        if not VLLM_AVAILABLE:
+            return False
+
         quant_method = quant_config.get("quant_method", "").lower()
         num_bits = quant_config.get("bits")
         group_size = quant_config.get("group_size")
@@ -294,9 +316,8 @@ class GPTQMarlinConfig(QuantizationConfig):
         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
             check_marlin_supported,
         )
-        from vllm.platforms import current_platform
 
-        if not current_platform.is_cuda():
+        if not _is_cuda:
             return False
 
         if quant_method != "gptq":
@@ -407,8 +428,14 @@ class MarlinConfig(QuantizationConfig):
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["MarlinLinearMethod"]:
+        if not VLLM_AVAILABLE:
+            raise ImportError("vllm is not installed")
+
         from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
 
+        # Delay import to avoid circular dependency
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
         if isinstance(layer, LinearBase) or (
             isinstance(layer, ParallelLMHead) and self.lm_head_quantized
         ):
sglang/srt/layers/quantization/kv_cache.py (new file)
@@ -0,0 +1,98 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kv_cache.py
+
+import logging
+
+import torch
+
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
+logger = logging.getLogger(__name__)
+
+
+class BaseKVCacheMethod(QuantizeMethodBase):
+    """
+    Quant method that adds `_k_scale` and `_v_scale` attributes to the
+    Attention layer to support loading those scaling factors from checkpoints.
+    The k/v_scale will be used to:
+    - quantize k/v_cache entries before saving them to the cache
+    - dequantize k/v_cache entries before fetching them from the cache
+
+    :param quant_config: the appropriate QuantizationConfig
+    """
+
+    def __init__(self, quant_config: QuantizationConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module):
+        """
+        Create "weight" (aka k_scale and v_scale) for an attention layer.
+        """
+        # Initialize the KV cache scales to -1.0, which is an invalid value.
+        # If the k/v_scale appears in the checkpoint, it will be
+        # overwritten when loading weights.
+        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False)
+
+    @classmethod
+    def is_fp8_fnuz(cls) -> bool:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+
+    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
+        raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
+        # regardless whether the kv-scale is available in the checkpoint.
+        # No need to process kv scales after loading if we are going to
+        # calculate them on the fly.
+        if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales:
+            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
+                # We prefer to use separate k_scale and v_scale if present
+                k_scale = layer.k_scale.to("cpu").tolist()
+                v_scale = layer.v_scale.to("cpu").tolist()
+                if _is_hip and self.is_fp8_fnuz():
+                    k_scale *= 2
+                    v_scale *= 2
+            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
+                # If no scales were loaded (both scales are invalid negative
+                # values), use the default value of 1.0
+                k_scale = 1.0
+                v_scale = 1.0
+            else:
+                # If we find a single kv_scale in the checkpoint, we remap
+                # kv_scale to k_scale during weight loading, and duplicate
+                # k_scale to v_scale here
+                assert layer.k_scale > 0.0
+                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
+                k_scale = scale_to_duplicate.to("cpu").tolist()
+                v_scale = scale_to_duplicate.to("cpu").tolist()
+                if _is_hip and self.is_fp8_fnuz():
+                    k_scale *= 2
+                    v_scale *= 2
+
+            if not isinstance(k_scale, float) or not isinstance(v_scale, float):
+                raise ValueError(
+                    "Only support per-tensor scaling factor " "for fp8 KV cache"
+                )
+
+            # These are used in the final Attention.forward()
+            layer._k_scale.copy_(k_scale)
+            layer._v_scale.copy_(v_scale)
+            layer._k_scale_float = k_scale
+            layer._v_scale_float = v_scale
+            if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
+                logger.warning(
+                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
+                    "may cause accuracy issues. Please make sure k/v_scale "
+                    "scaling factors are available in the fp8 checkpoint."
+                )
+
+            del layer.k_scale
+            del layer.v_scale
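Note (illustration, not part of the diff): the branches in process_weights_after_loading above reduce to a small scale-resolution rule. The helper below is a hypothetical standalone restatement of that rule for reference; it omits the ROCm fp8-fnuz doubling and the per-tensor type check.

def resolve_kv_scales(k_scale: float, v_scale: float):
    # Scales start at -1.0 (invalid) and are overwritten if the checkpoint has them.
    if k_scale > 0.0 and v_scale > 0.0:
        # Separate k/v scales were loaded; use them as-is.
        return k_scale, v_scale
    if k_scale < 0.0 and v_scale < 0.0:
        # Nothing was loaded; fall back to 1.0 for both.
        return 1.0, 1.0
    # A single kv_scale was remapped to k_scale at load time; duplicate it.
    shared = max(k_scale, v_scale)
    return shared, shared


assert resolve_kv_scales(-1.0, -1.0) == (1.0, 1.0)
assert resolve_kv_scales(0.5, -1.0) == (0.5, 0.5)
assert resolve_kv_scales(0.5, 0.25) == (0.5, 0.25)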
sglang/srt/layers/quantization/modelopt_quant.py
@@ -5,12 +5,6 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    convert_to_channelwise,
-    cutlass_fp8_supported,
-    requantize_with_max_scale,
-)
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
@@ -19,7 +13,15 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+)
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    convert_to_channelwise,
+    requantize_with_max_scale,
+)
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)