sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
CHANGED
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
 from functools import lru_cache
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange
 
 from sglang.srt.distributed import parallel_state
 from sglang.srt.distributed import utils as dist_utils
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
 from sglang.srt.utils import add_prefix
 
 
-
-
-
-
-
-
-        return rearrange(
-            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-        )
+# Copied from transformers, modeling_qwen2_vl.py
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
 
 
-def
-
-) -> torch.Tensor:
-
-
-
-
-
-
-    )
-
-
-
-    return torch.cat(
-        [
-            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-            x[..., ro_dim:],
-        ],
-        dim=-1,
-    )
-
-
-def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    t_ = t.float()
-    cos = freqs.cos()
-    sin = freqs.sin()
-    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
-    return output
+def apply_rotary_pos_emb_vision(
+    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    orig_q_dtype = q.dtype
+    orig_k_dtype = k.dtype
+    q, k = q.float(), k.float()
+
+    cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    q_embed = q_embed.to(orig_q_dtype)
+    k_embed = k_embed.to(orig_k_dtype)
+
+    return q_embed, k_embed
 
 
 class VisionAttention(nn.Module):
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
         use_context_forward (bool, default to True):
             if ``True``, a flash_attn style attention will be applied
             Otherwise, a full-sequence attention will be applied.
-
-            if ``True``, the softmax will be performed in
+        softmax_in_single_precision (bool, default to False):
+            if ``True``, the softmax will be performed in single-precision
             Otherwise, it will be performed in half-precision
 
     """
@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         dropout: float = 0.0,
         use_context_forward: bool = True,
-
+        softmax_in_single_precision: bool = False,
         flatten_batch: bool = False,
         prefix: str = "",
     ):
@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
                 head_size=self.head_size,
                 dropout=dropout,
                 flatten_batch=flatten_batch,
-
+                softmax_in_single_precision=softmax_in_single_precision,
             )
 
         self.use_qkv_parallel = use_qkv_parallel
@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
         self,
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
-
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         r"""
@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
            x: [b, s, embed_dim]
            cu_seqlens: [b]
         Returns:
-            [s, b,
+            [s, b, head * head_size]
         """
         bsz, s, _ = x.shape
+        head = self.num_attention_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
             q, k, v = qkv.chunk(3, dim=-1)
 
-            # [b, s, embed_dim] --> [b * s,
-            q, k, v = [
-                x.reshape(
-                    bsz * s, self.num_attention_heads_per_partition, -1
-                ).contiguous()
-                for x in (q, k, v)
-            ]
+            # [b, s, embed_dim] --> [b * s, head, head_size]
+            q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
             qkv, _ = self.qkv_proj(x)
             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
-
+                head,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
 
-        if
-
-
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            original_shape = q.shape
+            q, k = q.view(s, head, -1), k.view(s, head, -1)
+            q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
+            q, k = q.reshape(original_shape), k.reshape(original_shape)
 
         if self.use_qkv_parallel:
             pass
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
         head_size: int,
         dropout: float = 0.0,
         flatten_batch: bool = False,
-
+        softmax_in_single_precision: bool = False,
     ):
         super().__init__()
         self.head_size = head_size
         self.flatten_batch = flatten_batch
-        self.
+        self.softmax_in_single_precision = softmax_in_single_precision
         self.dropout = dropout
 
     @staticmethod
@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
         )
 
         if attention_mask is None:
-            if self.
+            if self.softmax_in_single_precision:
                 raise RuntimeError("Empty attention mask")
         else:
             attention_mask = attention_mask.to(device=q.device)
 
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
 
-        if self.
+        if self.softmax_in_single_precision:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
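The rotary helper in vision.py now mirrors the Hugging Face Qwen2-VL convention: it takes q and k together with precomputed cos/sin tables instead of raw frequencies, and restores the original dtypes on output. The snippet below is an illustrative sketch, not part of the package; it assumes cos/sin of shape (seq_len, head_dim) built by duplicating the half-dim frequencies, which broadcast over the head axis via the unsqueeze(-2) inside the new function.

# Illustrative sketch only (not from the diff): exercising the new signature.
import torch
from sglang.srt.layers.attention.vision import apply_rotary_pos_emb_vision

s, heads, head_dim = 16, 4, 64
q = torch.randn(s, heads, head_dim, dtype=torch.float16)
k = torch.randn(s, heads, head_dim, dtype=torch.float16)

# Standard non-interleaved RoPE tables: (s, head_dim) with the halves duplicated.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(s).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

q_rot, k_rot = apply_rotary_pos_emb_vision(q, k, cos, sin)
assert q_rot.shape == q.shape and q_rot.dtype == q.dtype  # input dtype is restored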
sglang/srt/layers/dp_attention.py
CHANGED
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -86,6 +90,27 @@ def get_attention_dp_size():
     return _DP_SIZE
 
 
+@contextmanager
+def disable_dp_size():
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
 
@@ -159,7 +184,8 @@ def dp_gather(
         layer_id != "embedding" or get_attention_tp_rank() == 0
     ):
         assert (
-            global_tokens.
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
+
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
 def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
     # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
     local_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
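The new disable_dp_size context manager temporarily reports a data-parallel size of 1, so speculative-decoding draft workers can run with a different parallel layout than the target model. A hedged usage sketch follows; the draft_runner object and helper function are hypothetical and not part of the diff, and dp attention must already be initialized or the assert inside the context manager fires.

# Illustrative sketch only (not from the diff).
from sglang.srt.layers.dp_attention import disable_dp_size, get_attention_dp_size

def run_draft_step(draft_runner, forward_batch):  # hypothetical helper
    with disable_dp_size():
        # Inside the block, get_attention_dp_size() returns 1, so the draft
        # model bypasses the dp gather/scatter paths used by the target model.
        assert get_attention_dp_size() == 1
        return draft_runner.forward(forward_batch)  # hypothetical call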
sglang/srt/layers/elementwise.py
ADDED
@@ -0,0 +1,411 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+fused_softcap_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
+    ],
+    key=["n_ele"],
+)
+
+
+@triton.jit
+def fused_softcap_kernel(
+    output_ptr,
+    input_ptr,
+    n_ele,
+    softcap_const: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_ele
+    x = tl.load(input_ptr + offsets, mask=mask)
+    fx = x.to(tl.float32)
+    fxs = fx / softcap_const
+    exped = tl.exp(2 * fxs)
+    top = exped - 1
+    bottom = exped + 1
+    output = top / bottom * softcap_const
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
+
+
+def fused_softcap(x, softcap_const, autotune=False):
+    output = torch.empty_like(x, dtype=torch.float32)
+    n_elements = output.numel()
+    if autotune:
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
+    else:
+        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
+            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
+        )
+    return output
+
+
+# cast to float + softcap
+class Softcap:
+    def __init__(self, softcap_const: float):
+        self.softcap_const = softcap_const
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.is_cuda:
+            return self.forward_cuda(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
+
+    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
+        return fused_softcap(x, self.softcap_const, autotune=autotune)
+
+
+rmsnorm_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
+    ],
+    key=["hidden_dim"],
+)
+
+
+@triton.jit
+def fused_dual_residual_rmsnorm_kernel(
+    output_ptr,
+    mid_ptr,
+    activ_ptr,
+    residual_ptr,
+    weight1_ptr,
+    weight2_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
+    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a2r = r + (a / rms * w1).to(r.dtype)
+    tl.store(
+        mid_ptr + input_start + offsets,
+        a2r,
+        mask=mask,
+    )
+
+    a2r = a2r.to(tl.float32)
+    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
+
+    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
+    w2 = w2_.to(tl.float32)
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a2r / rms2 * w2,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
+    fused_dual_residual_rmsnorm_kernel
+)
+
+
+def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
+    assert len(x.shape) == 2
+    assert x.shape == residual.shape and x.dtype == residual.dtype
+    output, mid = torch.empty_like(x), torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    if autotune:
+        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
+            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
+        )
+    else:
+        config = {
+            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+            "num_warps": max(
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            ),
+        }
+
+        fused_dual_residual_rmsnorm_kernel[(bs,)](
+            output,
+            mid,
+            x,
+            residual,
+            weight1,
+            weight2,
+            eps=eps,
+            hidden_dim=hidden_dim,
+            **config,
+        )
+
+    return output, mid
+
+
+@triton.jit
+def fused_rmsnorm_kernel(
+    output_ptr,
+    activ_ptr,
+    weight_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a_rms = a / rms * w1
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a_rms,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
+    assert len(x.shape) == 2
+    if inplace:
+        output = x
+    else:
+        output = torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_rmsnorm_kernel[(bs,)](
+        output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
+    )
+    return output
+
+
+class FusedDualResidualRMSNorm:
+    """
+    Fused implementation of
+        y = RMSNorm2(RMSNorm1(x) + residual))
+    """
+
+    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # the one after rmsnorm1
+        self.rmsnorm1 = rmsnorm1
+        self.rmsnorm2 = rmsnorm2
+        self.variance_epsilon = self.rmsnorm1.variance_epsilon
+        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
+        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_flashinfer(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, residual: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_dual_residual_rmsnorm(
+            x,
+            residual,
+            self.rmsnorm1.weight,
+            self.rmsnorm2.weight,
+            self.variance_epsilon,
+            autotune=autotune,
+        )
+
+    def forward_flashinfer(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1(x)
+        residual = normed1 + residual
+        return self.rmsnorm2(residual), residual
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1.forward_native(x)
+        residual = normed1 + residual
+        return self.rmsnorm2.forward_native(residual), residual
+
+
+# gelu on first half of vector
+@triton.jit
+def gelu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # gelu
+    # cast down before mul to better match training?
+    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
+    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def gelu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+        ),
+    }
+
+    gelu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
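The new elementwise.py collects standalone Triton elementwise kernels added in this release (softcap, fused dual-residual RMSNorm, gelu-and-mul). As a sanity check, the fused softcap kernel computes (exp(2x/c) - 1) / (exp(2x/c) + 1) * c, which is mathematically the same as the native tanh formulation; the sketch below is illustrative only, not part of the diff, and needs a CUDA device with Triton installed.

# Illustrative sketch only (not from the diff); requires CUDA + triton.
import torch
from sglang.srt.layers.elementwise import Softcap, fused_softcap

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
cap = Softcap(softcap_const=30.0)

y_triton = fused_softcap(x, 30.0)   # Triton kernel, float32 output
y_native = cap.forward_native(x)    # tanh(x / c) * c, also float32

torch.testing.assert_close(y_triton, y_native, rtol=1e-3, atol=1e-3)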
sglang/srt/layers/linear.py
CHANGED
@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
+    BlockQuantScaleParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)