sglang 0.4.3.post4__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +124 -665
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +78 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +9 -4
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -11
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/RECORD +124 -79
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
 from functools import lru_cache
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange
 
 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.utils import add_prefix
 
 
-
-
-
-
-
-
-        return rearrange(
-            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-        )
+# Copied from transformers, modeling_qwen2_vl.py
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
 
 
-def
-
-) -> torch.Tensor:
-
-
-
-
-
-
-
-
-    )
-
-
-
-    return torch.cat(
-        [
-            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-            x[..., ro_dim:],
-        ],
-        dim=-1,
-    )
-
-
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    t_ = t.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
 
 
 class VisionAttention(nn.Module):
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
         use_context_forward (bool, default to True):
             if ``True``, a flash_attn style attention will be applied
             Otherwise, a full-sequence attention will be applied.
-
-            if ``True``, the softmax will be performed in
+        softmax_in_single_precision (bool, default to False):
+            if ``True``, the softmax will be performed in single-precision
             Otherwise, it will be performed in half-precision
 
     """
@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         dropout: float = 0.0,
         use_context_forward: bool = True,
-
+        softmax_in_single_precision: bool = False,
         flatten_batch: bool = False,
         prefix: str = "",
     ):
@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
                 head_size=self.head_size,
                 dropout=dropout,
                 flatten_batch=flatten_batch,
-
+                softmax_in_single_precision=softmax_in_single_precision,
             )
 
         self.use_qkv_parallel = use_qkv_parallel
@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
         self,
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
-
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""
@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
             x: [b, s, embed_dim]
             cu_seqlens: [b]
         Returns:
-            [s, b,
+            [s, b, head * head_size]
         """
         bsz, s, _ = x.shape
+        head = self.num_attention_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.chunk(3, dim=-1)
 
-            # [b, s, embed_dim] --> [b * s,
-            q, k, v = [
-                x.reshape(
-                    bsz * s, self.num_attention_heads_per_partition, -1
-                ).contiguous()
-                for x in (q, k, v)
-            ]
+            # [b, s, embed_dim] --> [b * s, head, head_size]
+            q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
             qkv, _ = self.qkv_proj(x)
             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
-
+                head,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
 
-        if
-
-
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            original_shape = q.shape
+            q, k = q.view(s, head, -1), k.view(s, head, -1)
+            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q, k = q.reshape(original_shape), k.reshape(original_shape)
 
         if self.use_qkv_parallel:
             pass
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
         head_size: int,
         dropout: float = 0.0,
         flatten_batch: bool = False,
-
+        softmax_in_single_precision: bool = False,
     ):
         super().__init__()
         self.head_size = head_size
         self.flatten_batch = flatten_batch
-        self.
+        self.softmax_in_single_precision = softmax_in_single_precision
         self.dropout = dropout
 
     @staticmethod
@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
         )
 
         if attention_mask is None:
-            if self.
+            if self.softmax_in_single_precision:
                 raise RuntimeError("Empty attention mask")
            else:
                 attention_mask = attention_mask.to(device=q.device)
 
        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
 
-        if self.
+        if self.softmax_in_single_precision:
            scale = self.head_size**-0.5
            k_transposed = rearrange(k, "b h s d -> b h d s")
            attn_weights = torch.matmul(q, k_transposed) * scale
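Note: the following is a minimal, self-contained sketch (not part of the diff) of how the rotary helpers added above behave, using dummy tensors. The shapes and the re-declared helpers are illustrative assumptions, not sglang's actual calling code.

import torch

def rotate_half(x):
    # Same computation as the helper added in the hunk above.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_vision(q, k, cos, sin):
    # Same computation as the helper added in the hunk above:
    # up-cast to fp32, rotate q/k, cast back to the original dtypes.
    orig_q_dtype, orig_k_dtype = q.dtype, k.dtype
    q, k = q.float(), k.float()
    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed.to(orig_q_dtype), k_embed.to(orig_k_dtype)

# Dummy shapes (assumed): s tokens, `head` heads per partition, head_size d.
s, head, d = 16, 4, 64
q = torch.randn(s, head, d, dtype=torch.float16)
k = torch.randn(s, head, d, dtype=torch.float16)
cos, sin = torch.randn(s, d), torch.randn(s, d)  # broadcast over heads via unsqueeze(-2)

q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and q_rot.dtype == torch.float16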
sglang/srt/layers/linear.py CHANGED
@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
+    BlockQuantScaleParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
@@ -6,8 +6,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import is_cuda
 
-_is_cuda =
+_is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
@@ -3,7 +3,7 @@ from typing import Callable, List, Optional, Tuple
 
 import torch
 from torch.nn import Module
-from vllm import _custom_ops as
+from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,10 +26,18 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+
 
 logger = logging.getLogger(__name__)
 
+_is_hip = is_hip()
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -703,7 +711,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -717,12 +725,20 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             )
 
             for expert in range(layer.num_experts_per_partition):
-
-
-
-
-
-
+                if _is_cuda:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
+                else:
+                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                    )
+                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                    )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             return
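The per-expert loop above dispatches to sgl-kernel's scaled_fp8_quant on CUDA and to vLLM's op otherwise, storing a quantized weight plus a per-expert scale. As a rough illustration of that (weight, scale) contract, here is a pure-PyTorch sketch; ref_scaled_fp8_quant is a stand-in written for this note, and the scale convention and numerics of the real kernels may differ.

import torch

def ref_scaled_fp8_quant(weight: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Stand-in for scaled_fp8_quant: dynamic per-tensor scale so the
    # largest magnitude maps to the fp8 format's max representable value.
    finfo = torch.finfo(fp8_dtype)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return qweight, scale

# Per-expert quantization loop shaped like the hunk above, on dummy weights.
num_experts, n, k = 4, 64, 128
w13 = torch.randn(num_experts, n, k)
w13_q = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
w13_scale = torch.empty(num_experts)
for expert in range(num_experts):
    w13_q[expert, :, :], w13_scale[expert] = ref_scaled_fp8_quant(w13[expert, :, :])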
sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
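Each added config file maps a token count M (the JSON keys) to a Triton tile configuration for the fused-MoE kernel. The sketch below shows how such a table is typically consumed; the nearest-key selection mirrors the common fused-MoE tuning heuristic and is an assumption about, not a copy of, sglang's fused_moe.py lookup.

import json

def select_moe_config(configs: dict, M: int) -> dict:
    # Keys are stored as strings in the JSON; pick the closest batch size.
    return configs[min(configs, key=lambda key: abs(int(key) - M))]

table = json.loads(
    '{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64,'
    ' "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 4},'
    ' "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,'
    ' "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3}}'
)
print(select_moe_config(table, M=48))  # closest key is "64"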
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}