sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/fused_moe_triton/fused_moe.py

@@ -16,6 +16,7 @@ from vllm import _custom_ops as ops
 from sglang.srt.utils import direct_register_custom_op, get_device_name

 logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0


 @triton.jit
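The new module-level `padding_size` switch is read once from the `MOE_PADDING` environment variable; when it is enabled, expert weight tensors carry 128 extra elements along their last (K) dimension, and the launch code later subtracts that pad wherever the logical K is needed. A minimal sketch of the same gating pattern, assuming (hypothetically) that the pad is applied with `torch.nn.functional.pad`; only the env-var convention is taken from the diff:

```python
import os

import torch
import torch.nn.functional as F

# Same convention as the diff: MOE_PADDING=1 turns on a 128-element pad.
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

# Hypothetical expert weight tensor: [num_experts, N, K].
w = torch.randn(4, 256, 512)

# Pad only the last (K) dimension on the right; F.pad pads dims starting from the last.
w_padded = F.pad(w, (0, padding_size), mode="constant", value=0.0)

effective_k = w_padded.shape[2] - padding_size
print(w_padded.shape, effective_k)  # K grows by 128 when the flag is set; logical K stays 512
```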
@@ -58,6 +59,7 @@ def fused_moe_kernel(
     compute_type: tl.constexpr,
     use_fp8_w8a8: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -143,12 +145,21 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
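The `even_Ks` constexpr lets Triton compile a fast path that drops the per-iteration K-bounds mask whenever K is an exact multiple of `BLOCK_SIZE_K`; only the uneven case still pays for masked loads on the final partial block. A rough stand-alone sketch of the same tail-handling idea in plain PyTorch (hypothetical helper, no Triton), just to show why the mask is only needed when the last block can be partial:

```python
import torch

def blocked_matmul(a: torch.Tensor, b: torch.Tensor, block_k: int) -> torch.Tensor:
    """Accumulate a @ b in chunks of block_k along K, mirroring the kernel's K loop."""
    (M, K), (_, N) = a.shape, b.shape
    even_ks = K % block_k == 0  # the same divisibility check the launcher now does

    acc = torch.zeros(M, N, dtype=torch.float32)
    for k0 in range(0, K, block_k):
        if even_ks:
            # Every block is full: no bounds handling required (the fast path).
            a_blk, b_blk = a[:, k0 : k0 + block_k], b[k0 : k0 + block_k, :]
        else:
            # The last block may be partial: clamp to K (stand-in for the load mask).
            k1 = min(k0 + block_k, K)
            a_blk, b_blk = a[:, k0:k1], b[k0:k1, :]
        acc += a_blk.float() @ b_blk.float()
    return acc

a, b = torch.randn(4, 96), torch.randn(96, 8)
torch.testing.assert_close(blocked_matmul(a, b, block_k=32), a.float() @ b.float())
```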
@@ -254,7 +265,9 @@ def invoke_fused_moe_kernel(
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1

+    padded_size = 0
     if use_fp8_w8a8:
+        padded_size = padding_size
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
     elif use_int8_w8a16:
@@ -268,6 +281,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )

+    K = B.shape[2] - padded_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_Ks = True
+    else:
+        even_Ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -279,7 +298,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padded_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -296,6 +315,7 @@ def invoke_fused_moe_kernel(
         compute_type=compute_type,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        even_Ks=even_Ks,
         **config,
     )

@@ -351,20 +371,39 @@ def get_default_config(
     dtype: Optional[str],
     is_marlin: bool,
 ) -> Dict[str, int]:
-    config = {
-        "BLOCK_SIZE_M": 64,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 32,
-        "GROUP_SIZE_M": 8,
-    }
-    # A heuristic: fused marlin works faster with this config for small M
-    if M <= E or (is_marlin and M <= 32):
+    if dtype == "fp8_w8a8":
         config = {
-            "BLOCK_SIZE_M": 16,
-            "BLOCK_SIZE_N": 32,
-            "BLOCK_SIZE_K": 64,
-            "GROUP_SIZE_M": 1,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
         }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        # A heuristic: fused marlin works faster with this config for small M
+        if M <= E or (is_marlin and M <= 32):
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
     return config


@@ -645,8 +684,12 @@ def fused_experts_impl(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
 ):
+    padded_size = padding_size
+    if not use_fp8_w8a8:
+        padded_size = 0
+
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -668,7 +711,7 @@ def fused_experts_impl(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
     )
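Since the optional pad lives only in the stored weight tensors, the wrapper has to subtract `padded_size` wherever a logical dimension is compared or passed on: the hidden-size assertion against `hidden_states`, and the `w2` shape forwarded to the config lookup. A small sketch of that bookkeeping with made-up shapes (not sglang's real tensors):

```python
import torch

padding_size = 128        # value of the module-level switch when MOE_PADDING=1
use_fp8_w8a8 = True

hidden_states = torch.randn(16, 256)                # [num_tokens, hidden_size]
w1 = torch.randn(4, 2 * 512, 256 + padding_size)    # last (K) dim padded by 128
w2 = torch.randn(4, 256, 512 + padding_size)

padded_size = padding_size if use_fp8_w8a8 else 0

# Mirrors the updated assertion: compare against w1's *logical* K.
assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"

# Mirrors the shape handed to the config lookup: strip the pad from w2's last dim.
w2_logical_shape = (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size)
print(w2_logical_shape)  # (4, 256, 512)
```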
sglang/srt/layers/fused_moe_triton/layer.py

@@ -19,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.utils import set_weight_attrs

-if torch.cuda.is_available()
+if torch.cuda.is_available():
     from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 else:
     fused_experts = None  # type: ignore
sglang/srt/layers/logits_processor.py

@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
     # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor = None

-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
-    input_top_logprobs: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs: List = None
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    input_top_logprobs_val: List = None
+    input_top_logprobs_idx: List = None
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
+    output_top_logprobs_val: List = None
+    output_top_logprobs_idx: List = None


 @dataclasses.dataclass
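The previous single `*_top_logprobs` fields, which held per-position lists of `(logprob, token_id)` pairs, are split into parallel `_val`/`_idx` lists. A minimal stand-alone sketch of the new layout (a toy dataclass with made-up numbers, not the real sglang class):

```python
import dataclasses
from typing import List

@dataclasses.dataclass
class TopLogprobsDemo:
    # One entry per sequence; each entry holds that sequence's top-k for one step.
    output_top_logprobs_val: List = None
    output_top_logprobs_idx: List = None

# Old layout: [[(-0.1, 42), (-1.3, 7)]]  -- values and token ids zipped together.
# New layout: values and token ids kept in two aligned lists.
out = TopLogprobsDemo(
    output_top_logprobs_val=[[-0.1, -1.3]],
    output_top_logprobs_idx=[[42, 7]],
)

# Re-pairing them, when a consumer still wants tuples, is a plain zip.
pairs = [
    list(zip(vals, idxs))
    for vals, idxs in zip(out.output_top_logprobs_val, out.output_top_logprobs_idx)
]
print(pairs)  # [[(-0.1, 42), (-1.3, 7)]]
```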
@@ -89,76 +91,18 @@ class LogitsMetadata:


 class LogitsProcessor(nn.Module):
-    def __init__(
+    def __init__(
+        self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None
+    ):
         super().__init__()
         self.config = config
+        self.logit_scale = logit_scale
         self.do_tensor_parallel_all_gather = (
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
-
-    def _get_normalized_prompt_logprobs(
-        self,
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        self.final_logit_softcapping = getattr(
+            self.config, "final_logit_softcapping", None
         )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
-    @staticmethod
-    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-        max_k = max(logits_metadata.top_logprobs_nums)
-        ret = all_logprobs.topk(max_k, dim=1)
-        values = ret.values.tolist()
-        indices = ret.indices.tolist()
-
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
-            return None, output_top_logprobs
-        else:
-            input_top_logprobs, output_top_logprobs = [], []
-
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs.append([])
-                    output_top_logprobs.append([])
-                    continue
-
-                input_top_logprobs.append(
-                    [
-                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
-                        for j in range(pruned_len - 1)
-                    ]
-                )
-                output_top_logprobs.append(
-                    list(
-                        zip(
-                            values[pt + pruned_len - 1][:k],
-                            indices[pt + pruned_len - 1][:k],
-                        )
-                    )
-                )
-                pt += pruned_len
-
-            return input_top_logprobs, output_top_logprobs

     def forward(
         self,
@@ -184,38 +128,33 @@ class LogitsProcessor(nn.Module):
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()

-        if
-            last_logits.div_(self.
+        if self.final_logit_softcapping:
+            last_logits.div_(self.final_logit_softcapping)
             torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.
+            last_logits.mul_(self.final_logit_softcapping)

         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
             )
         else:
-            last_logprobs =
+            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                last_logits, logits_metadata
+            )

             if logits_metadata.forward_mode.is_decode():
                 if logits_metadata.return_top_logprob:
-
-                        last_logprobs, logits_metadata
-                    )
+                    output_top_logprobs_val, output_top_logprobs_idx = (
+                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
+                    )
                 else:
-
+                    output_top_logprobs_val = output_top_logprobs_idx = None
                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
-                    normalized_prompt_logprobs=None,
-                    input_token_logprobs=None,
-                    input_top_logprobs=None,
-                    output_top_logprobs=output_top_logprobs,
+                    output_top_logprobs_val=output_top_logprobs_val,
+                    output_top_logprobs_idx=output_top_logprobs_idx,
                 )
             else:
                 # Slice the requested tokens to compute logprob
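The softcap value is now cached on the module as `self.final_logit_softcapping` (read once via `getattr` in `__init__`) and applied with the usual divide / tanh / multiply sequence, which smoothly bounds logits to (-cap, cap) for models such as Gemma-2. A stand-alone sketch of that transform:

```python
import torch

def softcap_(logits: torch.Tensor, cap: float) -> torch.Tensor:
    """In-place logit soft-capping: cap * tanh(logits / cap), bounded to (-cap, cap)."""
    logits.div_(cap)
    torch.tanh(logits, out=logits)
    logits.mul_(cap)
    return logits

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap_(x, cap=30.0))  # large magnitudes saturate near +/-30; small ones pass almost unchanged
```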
@@ -233,24 +172,35 @@ class LogitsProcessor(nn.Module):
                 all_logits = self._get_logits(states, lm_head)
                 if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
+
+                # The LM head's weights may be zero-padded for parallelism. Remove any
+                # extra logits that this padding may have produced.
                 all_logits = all_logits[:, : self.config.vocab_size].float()

-                if
-                    all_logits.div_(self.
+                if self.final_logit_softcapping:
+                    all_logits.div_(self.final_logit_softcapping)
                     torch.tanh(all_logits, out=all_logits)
-                    all_logits.mul_(self.
+                    all_logits.mul_(self.final_logit_softcapping)

                 all_logprobs = all_logits
                 del all_logits, hidden_states
-
+
+                all_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                    all_logprobs, logits_metadata
+                )

                 # Get the logprob of top-k tokens
                 if logits_metadata.return_top_logprob:
-
-
-
+                    (
+                        input_top_logprobs_val,
+                        input_top_logprobs_idx,
+                        output_top_logprobs_val,
+                        output_top_logprobs_idx,
+                    ) = self.get_top_logprobs(all_logprobs, logits_metadata)
                 else:
-
+                    input_top_logprobs_val = input_top_logprobs_idx = (
+                        output_top_logprobs_val
+                    ) = output_top_logprobs_idx = None

                 # Compute the normalized logprobs for the requested tokens.
                 # Note that we pad a zero at the end for easy batching.
@@ -273,8 +223,10 @@ class LogitsProcessor(nn.Module):
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
                 input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
-                output_top_logprobs=output_top_logprobs,
+                input_top_logprobs_val=input_top_logprobs_val,
+                input_top_logprobs_idx=input_top_logprobs_idx,
+                output_top_logprobs_val=output_top_logprobs_val,
+                output_top_logprobs_idx=output_top_logprobs_idx,
             )

     def _get_logits(
@@ -288,8 +240,94 @@ class LogitsProcessor(nn.Module):
         else:
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+
+        # Optional scaling factor
+        if self.logit_scale is not None:
+            logits.mul_(self.logit_scale)  # In-place multiply
         return logits

+    @staticmethod
+    def _get_normalized_prompt_logprobs(
+        input_token_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+    ):
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )
+
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
+        sum_logp = (
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
+        return normalized_prompt_logprobs
+
+    @staticmethod
+    def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs_val.append(values[i][:k])
+                output_top_logprobs_idx.append(indices[i][:k])
+            return None, None, output_top_logprobs_val, output_top_logprobs_idx
+        else:
+            input_top_logprobs_val, input_top_logprobs_idx = [], []
+            output_top_logprobs_val, output_top_logprobs_idx = [], []
+
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
+                    input_top_logprobs_val.append([])
+                    input_top_logprobs_idx.append([])
+                    output_top_logprobs_val.append([])
+                    output_top_logprobs_idx.append([])
+                    continue
+
+                input_top_logprobs_val.append(
+                    [values[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                input_top_logprobs_idx.append(
+                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
+                )
+                output_top_logprobs_val.append(
+                    list(
+                        values[pt + pruned_len - 1][:k],
+                    )
+                )
+                output_top_logprobs_idx.append(
+                    list(
+                        indices[pt + pruned_len - 1][:k],
+                    )
+                )
+                pt += pruned_len
+
+            return (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+                output_top_logprobs_val,
+                output_top_logprobs_idx,
+            )
+
+    @staticmethod
+    def compute_temp_top_p_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        return torch.nn.functional.log_softmax(last_logits, dim=-1)
+

 def test():
     all_logprobs = torch.tensor(
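`get_top_logprobs` now always returns a 4-tuple `(input_val, input_idx, output_val, output_idx)`, which is why the decode path earlier slices `[2:4]`. A tiny shape-oriented sketch with fake logprobs and a hand-rolled metadata stand-in (hypothetical values, just to show the layout):

```python
import torch

# Fake per-sequence logprobs over a 10-token vocab for two decode requests.
logprobs = torch.log_softmax(torch.randn(2, 10), dim=-1)
top_logprobs_nums = [3, 2]  # per-request k, as carried by LogitsMetadata

max_k = max(top_logprobs_nums)
values, indices = logprobs.topk(max_k, dim=1)
values, indices = values.tolist(), indices.tolist()

output_top_logprobs_val = [values[i][:k] for i, k in enumerate(top_logprobs_nums)]
output_top_logprobs_idx = [indices[i][:k] for i, k in enumerate(top_logprobs_nums)]

# The decode branch returns (None, None, output_val, output_idx); callers keep [2:4].
result = (None, None, output_top_logprobs_val, output_top_logprobs_idx)
out_val, out_idx = result[2:4]
print(out_val, out_idx)  # per-request top-k logprob values and the matching token ids
```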
sglang/srt/layers/quantization/__init__.py

@@ -13,7 +13,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
 from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
 from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
 from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from vllm.model_executor.layers.quantization.gguf import GGUFConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import Fp8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -53,60 +53,16 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]


-def fp8_moe_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-    """Enhanced apply method for FP8 MoE."""
-    from sglang.srt.layers.fused_moe_triton import FusedMoE
-    from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
-
-    # Expert selection
-    topk_weights, topk_ids = FusedMoE.select_experts(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-        custom_routing_function=custom_routing_function,
-    )
-
-    # Expert fusion with FP8 quantization
-    return fused_experts(
-        x,
-        layer.w13_weight,
-        layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=True,
-        use_fp8_w8a8=True,
-        w1_scale=layer.w13_weight_scale,
-        w2_scale=layer.w2_weight_scale,
-        a1_scale=layer.w13_input_scale,
-        a2_scale=layer.w2_input_scale,
-    )
-
-
 def fp8_get_quant_method(self, layer, prefix):
     """Enhanced get_quant_method for FP8 config."""
     from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
     from vllm.model_executor.layers.quantization.utils.quant_utils import (
         is_layer_skipped,
     )

     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod

     if isinstance(layer, LinearBase):
         if is_layer_skipped(prefix, self.ignored_layers):
@@ -151,7 +107,6 @@ def awq_get_quant_method(self, layer, prefix):

 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)