sglang 0.2.14.post2__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +2 -0
- sglang/bench_latency.py +39 -28
- sglang/lang/interpreter.py +3 -0
- sglang/lang/ir.py +5 -0
- sglang/launch_server_llavavid.py +12 -12
- sglang/srt/configs/__init__.py +5 -0
- sglang/srt/configs/exaone.py +195 -0
- sglang/srt/constrained/fsm_cache.py +1 -1
- sglang/srt/conversation.py +24 -2
- sglang/srt/hf_transformers_utils.py +11 -11
- sglang/srt/layers/extend_attention.py +13 -8
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +69 -16
- sglang/srt/managers/controller_multi.py +5 -5
- sglang/srt/managers/controller_single.py +5 -5
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +20 -8
- sglang/srt/managers/tokenizer_manager.py +2 -2
- sglang/srt/managers/tp_worker.py +38 -26
- sglang/srt/model_config.py +3 -3
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +68 -23
- sglang/srt/model_executor/model_runner.py +14 -12
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +57 -25
- sglang/srt/models/exaone.py +399 -0
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +5 -1
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +5 -1
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +6 -2
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +6 -2
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +16 -1
- sglang/srt/openai_api/protocol.py +5 -5
- sglang/srt/sampling/sampling_batch_info.py +79 -6
- sglang/srt/server.py +6 -6
- sglang/srt/utils.py +0 -3
- sglang/test/runners.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/METADATA +7 -7
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/RECORD +55 -52
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
- {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,12 +22,15 @@ from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +47,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -169,6 +175,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -179,6 +186,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)
 
         ret.compute_extend_infos(batch)
@@ -255,6 +264,42 @@ class InputMetadata:
         )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -278,17 +323,18 @@ def update_flashinfer_indices
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        …
-        )
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
+
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
         if forward_mode == ForwardMode.DECODE:
@@ -358,18 +404,17 @@ def update_flashinfer_indices
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        …
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
 
         if forward_mode == ForwardMode.DECODE:
             # CUDA graph uses different flashinfer_decode_wrapper
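Both rewritten branches of update_flashinfer_indices now build kv_indices on the GPU with the new create_flashinfer_kv_indices_triton kernel, launched with one program per request, instead of assembling the index tensor on the host (the removed host-side construction is only partially visible in this view). As a reading aid, here is a minimal pure-PyTorch sketch of what the kernel computes; the helper name and the Python loop are illustrative only and not part of sglang:

```python
import torch

def build_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens,
                               kv_indptr, kv_start_idx=None):
    # For request i, copy req_to_token[req_pool_indices[i], start:start+len_i]
    # into kv_indices[kv_indptr[i]:kv_indptr[i+1]] -- the same gather the Triton
    # kernel performs in 512-element blocks, one program per request.
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32,
                             device=req_to_token.device)
    for i in range(req_pool_indices.numel()):
        start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[i])
        kv_indices[int(kv_indptr[i]) : int(kv_indptr[i + 1])] = req_to_token[
            int(req_pool_indices[i]), start:end
        ]
    return kv_indices
```

Keeping the whole gather on the device avoids a per-request host loop on a path that runs for every batch before the flashinfer wrappers are updated.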
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -193,9 +195,9 @@ class ModelRunner:
             monkey_patch_vllm_qvk_linear_loader()
 
         self.dtype = self.vllm_model_config.dtype
-        if self.model_config.…
+        if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
-                self.model_config.…
+                self.model_config.model_override_args
             )
 
         self.model = get_model(
@@ -346,13 +348,7 @@
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            …
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -524,7 +520,11 @@
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and not batch.sampling_info.has_bias()
+        ):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -573,7 +573,9 @@
             input_metadata.image_offsets,
         )
 
-    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
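Two behavioral changes stand out in model_runner.py: forward_decode now skips CUDA-graph replay whenever the batch carries a logit bias (not batch.sampling_info.has_bias()), and forward is annotated to return a (SampleOutput, LogitsProcessorOutput) pair because sampling has moved into the model forward pass. A caller-side sketch of the resulting call shape; this fragment is not sglang source, and the field name on the last line is an assumption rather than something visible in this diff:

```python
# Illustrative fragment (not sglang code): the worker now receives the sampled
# tokens and the logits metadata together from a single forward call.
sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = sample_output.batch_next_token_ids  # assumed SampleOutput field
```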
sglang/srt/models/chatglm.py
CHANGED
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
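chatglm.py is the first of several model files converted to the same pattern: drop vLLM's Sampler and the separate sample() method, construct sglang's own Sampler in __init__, and have forward() return a (sample_output, logits_output) pair driven by input_metadata.sampling_info. The commandr.py, dbrx.py, deepseek.py and deepseek_v2.py diffs below repeat the same change. A toy, runnable sketch of the contract's shape in plain torch (no sglang classes; the function and tensors here are made up for illustration):

```python
import torch

def toy_forward(hidden_states, lm_head_weight, temperature=1.0):
    # Compute logits (the LogitsProcessor role) and sample inside forward (the
    # Sampler role), returning both instead of bare logits.
    logits_output = hidden_states @ lm_head_weight.T
    probs = torch.softmax(logits_output / temperature, dim=-1)
    sample_output = torch.multinomial(probs, num_samples=1)
    return sample_output, logits_output

hidden = torch.randn(2, 16)
lm_head = torch.randn(100, 16)
tokens, logits = toy_forward(hidden, lm_head)
print(tokens.shape, logits.shape)  # torch.Size([2, 1]) torch.Size([2, 100])
```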
sglang/srt/models/commandr.py
CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -340,9 +342,11 @@
             positions,
             input_metadata,
         )
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -391,9 +393,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,9 +396,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -19,6 +19,7 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -45,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -160,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
@@ -254,11 +265,6 @@
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        # self.qk_head_dim,
-        # self.scaling,
-        # num_kv_heads=self.num_heads)
-
         # TODO, support head_size 192
         self.attn = RadixAttention(
             self.num_local_heads,
@@ -282,7 +288,7 @@
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
@@ -416,12 +422,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )
 
-        …
-        ).split([qk_nope_head_dim, v_head_dim], dim=1)
-        self.w_kc = w_kc
-        self.w_vc = w_vc
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None
 
     def forward(
         self,
@@ -442,8 +445,17 @@
             -1, self.num_local_heads, self.qk_head_dim
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        …
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
 
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
@@ -458,16 +470,21 @@
 
         attn_output = self.attn(q_input, k_input, v_input, input_metadata)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        attn_bmm_output = attn_output.new_empty(
-            q_len, self.num_local_heads, self.v_head_dim
-        )
-        torch.bmm(
-            attn_output.transpose(0, 1),
-            self.w_vc.transpose(1, 2).contiguous(),
-            out=attn_bmm_output.transpose(0, 1),
-        )
 
-        …
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         output, _ = self.o_proj(attn_output)
 
         return output
@@ -632,6 +649,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -640,9 +658,11 @@
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -695,7 +715,7 @@
                 weight_loader(
                     param,
                     loaded_weight,
-                    …
+                    name,
                     shard_id=shard_id,
                     expert_id=expert_id,
                 )
@@ -711,5 +731,17 @@
                 )
                 weight_loader(param, loaded_weight)
 
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+
 
 EntryClass = DeepseekV2ForCausalLM
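The new input_to_float8 helper computes a single per-tensor scale (finfo.max / amax), saturates the scaled values to the fp8 range, and returns the quantized tensor together with the reciprocal scale; the MLA path feeds that pair, plus the per-layer w_scale captured in load_weights, into flashinfer's bmm_fp8. A quick round-trip sketch of the helper on its own (standalone; CPU is fine for illustration, but it requires a PyTorch build with float8_e4m3fn support):

```python
import torch

def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Same logic as the helper added in deepseek_v2.py: one scale per tensor.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

x = torch.randn(4, 8)
x_fp8, inv_scale = input_to_float8(x)
x_restored = x_fp8.to(torch.float32) * inv_scale
print(torch.max((x - x_restored).abs()))  # small, bounded by fp8 rounding error
```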