sglang 0.2.14.post2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the content of publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- sglang/api.py +2 -0
- sglang/bench_latency.py +39 -28
- sglang/lang/backend/runtime_endpoint.py +8 -4
- sglang/lang/interpreter.py +3 -0
- sglang/lang/ir.py +5 -0
- sglang/launch_server_llavavid.py +12 -12
- sglang/srt/configs/__init__.py +5 -0
- sglang/srt/configs/exaone.py +195 -0
- sglang/srt/constrained/fsm_cache.py +1 -1
- sglang/srt/conversation.py +24 -2
- sglang/srt/hf_transformers_utils.py +12 -12
- sglang/srt/layers/extend_attention.py +13 -8
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +94 -17
- sglang/srt/managers/controller_multi.py +5 -5
- sglang/srt/managers/controller_single.py +5 -5
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +26 -11
- sglang/srt/managers/tokenizer_manager.py +9 -9
- sglang/srt/managers/tp_worker.py +38 -26
- sglang/srt/model_config.py +3 -3
- sglang/srt/model_executor/cuda_graph_runner.py +26 -9
- sglang/srt/model_executor/forward_batch_info.py +68 -23
- sglang/srt/model_executor/model_runner.py +15 -22
- sglang/srt/models/chatglm.py +9 -15
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +57 -25
- sglang/srt/models/exaone.py +368 -0
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +5 -1
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +5 -1
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/{llama2.py → llama.py} +25 -45
- sglang/srt/models/llama_classification.py +34 -41
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +8 -11
- sglang/srt/models/llavavid.py +5 -6
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -2
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +6 -2
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +16 -1
- sglang/srt/openai_api/protocol.py +5 -5
- sglang/srt/sampling/sampling_batch_info.py +75 -6
- sglang/srt/server.py +6 -6
- sglang/srt/utils.py +0 -3
- sglang/test/runners.py +1 -1
- sglang/test/test_programs.py +68 -0
- sglang/test/test_utils.py +4 -0
- sglang/utils.py +39 -0
- sglang/version.py +1 -1
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/METADATA +9 -8
- sglang-0.3.0.dist-info/RECORD +118 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/WHEEL +1 -1
- sglang-0.2.14.post2.dist-info/RECORD +0 -115
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -44,8 +46,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
 
@@ -144,6 +148,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
@@ -235,6 +243,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -299,27 +308,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-
+        sample_output, logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-
-            next_token_logits=
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )
 
         # Extract logprobs
         if batch.return_logprob:
-
-
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -327,8 +344,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-
-
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return
+        return sample_output, logits_output
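The new `SamplingBatchInfo` plumbing above follows the usual CUDA-graph rule that every tensor the graph touches must keep the same storage across replays: a dummy buffer sized for `max_bs` is allocated once at capture time (`SamplingBatchInfo.dummy_one`) and the live batch's values are copied into its first `raw_bs` rows before each replay (`inplace_assign`). A minimal sketch of that buffer-reuse pattern, with made-up field names rather than the real `SamplingBatchInfo` API:

import torch


class GraphSafeSamplingBuffers:
    """Pre-allocated, fixed-storage buffers a captured CUDA graph can read."""

    def __init__(self, max_bs: int, vocab_size: int):
        self.temperatures = torch.ones(max_bs, 1)
        self.top_ps = torch.ones(max_bs, 1)
        self.logit_bias = torch.zeros(max_bs, vocab_size)

    def inplace_assign(self, raw_bs: int, temperatures, top_ps):
        # Copy the live batch into the first raw_bs rows; never reallocate,
        # or the captured graph would keep reading the old storage.
        self.temperatures[:raw_bs].copy_(temperatures)
        self.top_ps[:raw_bs].copy_(top_ps)


buffers = GraphSafeSamplingBuffers(max_bs=8, vocab_size=32000)
buffers.inplace_assign(2, torch.tensor([[0.7], [1.0]]), torch.tensor([[0.9], [0.95]]))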
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,12 +22,15 @@ from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +47,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -169,6 +175,7 @@ class InputMetadata:
    ):
        ret = cls(
            forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
            batch_size=batch.batch_size(),
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
@@ -179,6 +186,8 @@ class InputMetadata:
            top_logprobs_nums=batch.top_logprobs_nums,
        )
 
+        ret.sampling_info.prepare_penalties()
+
        ret.compute_positions(batch)
 
        ret.compute_extend_infos(batch)
@@ -255,6 +264,42 @@ class InputMetadata:
        )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
    forward_mode,
    model_runner,
@@ -278,17 +323,18 @@ def update_flashinfer_indices(
 
        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-
-
-
-
-
-
-
-
-
-        )
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
+
        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
        if forward_mode == ForwardMode.DECODE:
@@ -358,18 +404,17 @@ def update_flashinfer_indices(
 
        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-
-
-
-
-
-
-
-
-
-
-
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
 
        if forward_mode == ForwardMode.DECODE:
            # CUDA graph uses different flashinfer_decode_wrapper
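The `create_flashinfer_kv_indices_triton` kernel added above replaces the removed Python-side gather in `update_flashinfer_indices`: program `i` copies `req_to_token[req_pool_indices[i], kv_start : kv_start + paged_kernel_lens[i]]` into the flat `kv_indices` buffer starting at offset `kv_indptr[i]`. A plain-PyTorch reference of the same computation, written here only to describe what the kernel does (it is not part of sglang):

import torch


def create_flashinfer_kv_indices_reference(
    req_to_token,        # [max_batch, max_context_len] int32
    req_pool_indices,    # [batch_size] int32
    paged_kernel_lens,   # [batch_size] int32
    kv_indptr,           # [batch_size + 1] int32, exclusive prefix sum of lens
    kv_start_idx=None,   # optional [batch_size] int32 per-request start offset
):
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(req_pool_indices.numel()):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        length = int(paged_kernel_lens[i])
        row = req_to_token[int(req_pool_indices[i])]
        out_begin = int(kv_indptr[i])
        kv_indices[out_begin : out_begin + length] = row[start : start + length]
    return kv_indices


req_to_token = torch.arange(12, dtype=torch.int32).reshape(3, 4)
print(create_flashinfer_kv_indices_reference(
    req_to_token,
    req_pool_indices=torch.tensor([2, 0], dtype=torch.int32),
    paged_kernel_lens=torch.tensor([3, 2], dtype=torch.int32),
    kv_indptr=torch.tensor([0, 3, 5], dtype=torch.int32),
))  # tensor([ 8,  9, 10,  0,  1], dtype=torch.int32)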
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -160,6 +162,7 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
+        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -193,9 +196,9 @@ class ModelRunner:
             monkey_patch_vllm_qvk_linear_loader()
 
         self.dtype = self.vllm_model_config.dtype
-        if self.model_config.
+        if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
-                self.model_config.
+                self.model_config.model_override_args
             )
 
         self.model = get_model(
@@ -346,13 +349,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -524,7 +521,11 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and batch.sampling_info.can_run_in_cuda_graph()
+        ):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -573,7 +574,9 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
@@ -604,16 +607,6 @@ def import_model_classes():
                 assert entry.__name__ not in model_arch_name_to_cls
                 model_arch_name_to_cls[entry.__name__] = entry
 
-            # compat: some models such as chatglm has incorrect class set in config.json
-            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-            if hasattr(module, "EntryClassRemapping") and isinstance(
-                module.EntryClassRemapping, list
-            ):
-                for remap in module.EntryClassRemapping:
-                    if isinstance(remap, tuple) and len(remap) == 2:
-                        assert remap[0] not in model_arch_name_to_cls
-                        model_arch_name_to_cls[remap[0]] = remap[1]
-
     return model_arch_name_to_cls
 
 
sglang/srt/models/chatglm.py
CHANGED
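chatglm.py is the first of the model files below (commandr.py, dbrx.py, deepseek.py, deepseek_v2.py and the other entries under sglang/srt/models/) to receive the same refactor: the vLLM `Sampler`/`SamplingMetadata` imports and the separate sampling method are dropped, the model constructs `sglang.srt.layers.sampler.Sampler` next to its `LogitsProcessor`, and `forward` returns the `(sample_output, logits_output)` pair that `ModelRunner.forward` is now typed to return. A schematic of that per-model contract, using toy classes rather than the real sglang ones:

import torch


class ToyLogitsProcessor:
    # Stand-in for sglang.srt.layers.logits_processor.LogitsProcessor.
    def __call__(self, hidden_states, lm_head_weight):
        return hidden_states @ lm_head_weight.T  # plays the role of logits_output


class ToySampler:
    # Stand-in for sglang.srt.layers.sampler.Sampler.
    def __call__(self, logits_output, sampling_info=None):
        return logits_output.argmax(dim=-1)  # plays the role of sample_output


class ToyCausalLM(torch.nn.Module):
    # Mirrors the forward contract the diff applies to each model file.
    def __init__(self, hidden=4, vocab=10):
        super().__init__()
        self.lm_head = torch.nn.Linear(hidden, vocab, bias=False)
        self.logits_processor = ToyLogitsProcessor()
        self.sampler = ToySampler()

    def forward(self, hidden_states, sampling_info=None):
        logits_output = self.logits_processor(hidden_states, self.lm_head.weight)
        sample_output = self.sampler(logits_output, sampling_info)
        return sample_output, logits_output


sample_output, logits_output = ToyCausalLM()(torch.randn(2, 4))
print(sample_output.shape, logits_output.shape)  # torch.Size([2]) torch.Size([2, 10])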
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
@@ -410,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-
-
-
+class ChatGLMModel(ChatGLMForCausalLM):
+    pass
+
+
+EntryClass = [ChatGLMForCausalLM, ChatGLMModel]
sglang/srt/models/commandr.py
CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py
CHANGED
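Besides the shared sampler change, the deepseek_v2.py hunks below wire the MLA attention path up for FP8: `kv_b_proj` is split into `w_kc`/`w_vc` after weight loading, and the two `torch.bmm` calls gain an FP8 branch built on the new `input_to_float8` helper plus flashinfer's `bmm_fp8`. The helper does dynamic per-tensor quantization: scale the tensor so its absolute maximum reaches the float8_e4m3fn limit, cast, and return the reciprocal scale used to dequantize. A small round-trip check of that behavior (assumes a PyTorch build with the float8_e4m3fn dtype; flashinfer is not needed for this snippet):

import torch


def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Same logic as the helper added in deepseek_v2.py.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


x = torch.randn(64, 128, dtype=torch.bfloat16)
x_fp8, inv_scale = input_to_float8(x)
x_restored = x_fp8.to(torch.float32) * inv_scale  # dequantize with the reciprocal scale
print((x.float() - x_restored).abs().max())       # small quantization error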
@@ -19,6 +19,7 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -45,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -160,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
@@ -254,11 +265,6 @@ class DeepseekV2Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
         # TODO, support head_size 192
         self.attn = RadixAttention(
             self.num_local_heads,
@@ -282,7 +288,7 @@ class DeepseekV2Attention(nn.Module):
         q = self.q_proj(hidden_states)[0].view(
             -1, self.num_local_heads, self.qk_head_dim
         )
-
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
@@ -416,12 +422,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )
 
-
-
-
-        ).split([qk_nope_head_dim, v_head_dim], dim=1)
-        self.w_kc = w_kc
-        self.w_vc = w_vc
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None
 
     def forward(
         self,
@@ -442,8 +445,17 @@ class DeepseekV2AttentionMLA(nn.Module):
             -1, self.num_local_heads, self.qk_head_dim
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-
-
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
 
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
@@ -458,16 +470,21 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         attn_output = self.attn(q_input, k_input, v_input, input_metadata)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        attn_bmm_output = attn_output.new_empty(
-            q_len, self.num_local_heads, self.v_head_dim
-        )
-        torch.bmm(
-            attn_output.transpose(0, 1),
-            self.w_vc.transpose(1, 2).contiguous(),
-            out=attn_bmm_output.transpose(0, 1),
-        )
 
-
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         output, _ = self.o_proj(attn_output)
 
         return output
@@ -632,6 +649,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -640,9 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -695,7 +715,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 weight_loader(
                     param,
                     loaded_weight,
-
+                    name,
                     shard_id=shard_id,
                     expert_id=expert_id,
                 )
@@ -711,5 +731,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+
 
 EntryClass = DeepseekV2ForCausalLM
|