sglang 0.4.3.post4__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (124)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +1 -3
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +32 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +124 -665
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +6 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +78 -17
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/model_executor/cuda_graph_runner.py +9 -4
  95. sglang/srt/model_executor/forward_batch_info.py +12 -8
  96. sglang/srt/model_executor/model_runner.py +63 -63
  97. sglang/srt/model_loader/loader.py +2 -1
  98. sglang/srt/model_loader/weight_utils.py +1 -1
  99. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  100. sglang/srt/models/deepseek_nextn.py +23 -3
  101. sglang/srt/models/deepseek_v2.py +25 -19
  102. sglang/srt/models/minicpmv.py +28 -89
  103. sglang/srt/models/mllama.py +1 -1
  104. sglang/srt/models/qwen2.py +0 -1
  105. sglang/srt/models/qwen2_5_vl.py +25 -50
  106. sglang/srt/models/qwen2_vl.py +33 -49
  107. sglang/srt/openai_api/adapter.py +37 -15
  108. sglang/srt/openai_api/protocol.py +8 -1
  109. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  110. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  111. sglang/srt/server_args.py +19 -11
  112. sglang/srt/speculative/eagle_worker.py +75 -39
  113. sglang/srt/utils.py +104 -9
  114. sglang/test/runners.py +104 -10
  115. sglang/test/test_block_fp8.py +106 -16
  116. sglang/test/test_custom_ops.py +88 -0
  117. sglang/test/test_utils.py +20 -4
  118. sglang/utils.py +0 -4
  119. sglang/version.py +1 -1
  120. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -10
  121. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/RECORD +124 -79
  122. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  123. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  124. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
 from functools import lru_cache
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange
 
 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.utils import add_prefix
 
 
-def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
-    if not interleaved:
-        x1, x2 = x.chunk(2, dim=-1)
-        return torch.cat((-x2, x1), dim=-1)
-    else:
-        x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange(
-            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-        )
+# Copied from transformers, modeling_qwen2_vl.py
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_emb_torch(
-    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
-) -> torch.Tensor:
-    """
-    x: (batch_size, seqlen, nheads, headdim)
-    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-    """
-    ro_dim = cos.shape[-1] * 2
-    assert ro_dim <= x.shape[-1]
-    cos = repeat(
-        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    sin = repeat(
-        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-    )
-    return torch.cat(
-        [
-            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-            x[..., ro_dim:],
-        ],
-        dim=-1,
-    )
-
-
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    t_ = t.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
 
 
 class VisionAttention(nn.Module):
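
The new helper follows the transformers Qwen2-VL convention: cos and sin carry the full head dimension and are broadcast over the head axis via unsqueeze(-2). Below is a minimal, self-contained sketch of that calling convention; the two functions mirror the hunk above (with the float32 round-trip omitted), while the toy shapes and the duplication of freqs along the last dimension are illustrative assumptions rather than part of the diff.

import torch

def rotate_half(x):
    # Same as the helper added in this diff: split the head dim in half and swap.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_vision(q, k, cos, sin):
    # Same math as the new helper above, minus the dtype bookkeeping.
    cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)  # broadcast over the head axis
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

s, heads, head_dim = 16, 4, 32                       # toy sizes (assumption)
q = torch.randn(s, heads, head_dim)
k = torch.randn(s, heads, head_dim)
freqs = torch.randn(s, head_dim // 2)                # per-position rotary frequencies
cos = torch.cat((freqs, freqs), dim=-1).cos()        # duplicate so rotate_half lines up
sin = torch.cat((freqs, freqs), dim=-1).sin()
q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
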
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
         use_context_forward (bool, default to True):
             if ``True``, a flash_attn style attention will be applied
             Otherwise, a full-sequence attention will be applied.
-        use_full_precision_softmax (bool, default to False):
-            if ``True``, the softmax will be performed in full-precision
+        softmax_in_single_precision (bool, default to False):
+            if ``True``, the softmax will be performed in single-precision
             Otherwise, it will be performed in half-precision
 
     """
@@ -90,7 +72,7 @@
         quant_config: Optional[QuantizationConfig] = None,
         dropout: float = 0.0,
         use_context_forward: bool = True,
-        use_full_precision_softmax: bool = False,
+        softmax_in_single_precision: bool = False,
         flatten_batch: bool = False,
         prefix: str = "",
     ):
@@ -113,7 +95,7 @@
                 head_size=self.head_size,
                 dropout=dropout,
                 flatten_batch=flatten_batch,
-                use_full_precision_softmax=use_full_precision_softmax,
+                softmax_in_single_precision=softmax_in_single_precision,
             )
 
         self.use_qkv_parallel = use_qkv_parallel
@@ -143,7 +125,7 @@
         self,
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
-        rotary_pos_emb: torch.Tensor = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""
@@ -151,21 +133,17 @@
             x: [b, s, embed_dim]
             cu_seqlens: [b]
         Returns:
-            [s, b, num_heads * head]
+            [s, b, head * head_size]
         """
         bsz, s, _ = x.shape
+        head = self.num_attention_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.chunk(3, dim=-1)
 
-            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
-            q, k, v = [
-                x.reshape(
-                    bsz * s, self.num_attention_heads_per_partition, -1
-                ).contiguous()
-                for x in (q, k, v)
-            ]
+            # [b, s, embed_dim] --> [b * s, head, head_size]
+            q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
@@ -173,7 +151,7 @@
             qkv, _ = self.qkv_proj(x)
             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
-                self.num_attention_heads_per_partition,
+                head,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
@@ -186,9 +164,12 @@
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
 
-            if rotary_pos_emb is not None:
-                q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-                k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+            if position_embeddings is not None:
+                cos, sin = position_embeddings
+                original_shape = q.shape
+                q, k = q.view(s, head, -1), k.view(s, head, -1)
+                q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+                q, k = q.reshape(original_shape), k.reshape(original_shape)
 
         if self.use_qkv_parallel:
             pass
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
         head_size: int,
         dropout: float = 0.0,
         flatten_batch: bool = False,
-        use_full_precision_softmax: bool = False,
+        softmax_in_single_precision: bool = False,
     ):
         super().__init__()
         self.head_size = head_size
         self.flatten_batch = flatten_batch
-        self.use_full_precision_softmax = use_full_precision_softmax
+        self.softmax_in_single_precision = softmax_in_single_precision
         self.dropout = dropout
 
     @staticmethod
@@ -319,14 +300,14 @@
             )
 
         if attention_mask is None:
-            if self.use_full_precision_softmax:
+            if self.softmax_in_single_precision:
                raise RuntimeError("Empty attention mask")
         else:
             attention_mask = attention_mask.to(device=q.device)
 
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
 
-        if self.use_full_precision_softmax:
+        if self.softmax_in_single_precision:
            scale = self.head_size**-0.5
            k_transposed = rearrange(k, "b h s d -> b h d s")
            attn_weights = torch.matmul(q, k_transposed) * scale
@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
+    BlockQuantScaleParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -6,8 +6,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import is_cuda
 
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 from torch.nn import Module
-from vllm import _custom_ops as ops
+from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,10 +26,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+
 
 logger = logging.getLogger(__name__)
 
+_is_hip = is_hip()
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -703,7 +711,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -717,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )
 
             for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
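
The hunk above routes per-expert weight quantization through sgl-kernel's scaled_fp8_quant when CUDA is available and keeps vLLM's op as the fallback. As a rough pure-PyTorch reference for what a dynamic per-tensor scaled fp8 quantization of this kind computes (the clamp epsilon and the exact rounding behavior are illustrative assumptions, not taken from either kernel):

import torch

def per_tensor_fp8_quant_reference(w: torch.Tensor):
    # Map the tensor's absolute maximum onto the fp8 e4m3 range (about +/-448),
    # then cast; the second return value plays the role of the per-expert weight scale.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = w.abs().max().clamp(min=1e-12) / finfo.max
    q = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

w = torch.randn(512, 256)
q, scale = per_tensor_fp8_quant_reference(w)
# Dequantizing should roughly reproduce the original weights.
print(q.dtype, scale.item(), (q.float() * scale - w).abs().max().item())
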
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
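
The two JSON files above add batch-size-keyed Triton tuning tables (block sizes, group size, warps, pipeline stages) for the fused MoE kernel on specific expert-count/intermediate-size/device/dtype combinations. Below is a sketch of how such a table is typically consumed at runtime; the nearest-key heuristic and the local file path are assumptions for illustration, and the actual selection logic presumably lives in fused_moe_triton/fused_moe.py.

import json

# Hypothetical local copy of one of the new config files; inside the package the
# file name is derived from E, N, dtype and the detected device name.
CONFIG_PATH = "E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json"

def pick_tuned_config(configs: dict, num_tokens: int) -> dict:
    # Keys are the benchmarked batch sizes (M); take the entry closest to the actual M.
    best_key = min(configs, key=lambda key: abs(int(key) - num_tokens))
    return configs[best_key]

with open(CONFIG_PATH) as f:
    configs = json.load(f)

# For 200 tokens this would land on the "256" entry of a table like the one above.
print(pick_tuned_config(configs, num_tokens=200))
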