sglang 0.2.14.post2__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +12 -12
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -11
  11. sglang/srt/layers/extend_attention.py +13 -8
  12. sglang/srt/layers/logits_processor.py +4 -4
  13. sglang/srt/layers/sampler.py +69 -16
  14. sglang/srt/managers/controller_multi.py +5 -5
  15. sglang/srt/managers/controller_single.py +5 -5
  16. sglang/srt/managers/io_struct.py +6 -1
  17. sglang/srt/managers/schedule_batch.py +20 -8
  18. sglang/srt/managers/tokenizer_manager.py +2 -2
  19. sglang/srt/managers/tp_worker.py +38 -26
  20. sglang/srt/model_config.py +3 -3
  21. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  22. sglang/srt/model_executor/forward_batch_info.py +68 -23
  23. sglang/srt/model_executor/model_runner.py +14 -12
  24. sglang/srt/models/chatglm.py +4 -12
  25. sglang/srt/models/commandr.py +5 -1
  26. sglang/srt/models/dbrx.py +5 -1
  27. sglang/srt/models/deepseek.py +5 -1
  28. sglang/srt/models/deepseek_v2.py +57 -25
  29. sglang/srt/models/exaone.py +399 -0
  30. sglang/srt/models/gemma.py +5 -1
  31. sglang/srt/models/gemma2.py +5 -1
  32. sglang/srt/models/gpt_bigcode.py +5 -1
  33. sglang/srt/models/grok.py +5 -1
  34. sglang/srt/models/internlm2.py +5 -1
  35. sglang/srt/models/llama2.py +7 -3
  36. sglang/srt/models/llama_classification.py +2 -2
  37. sglang/srt/models/minicpm.py +5 -1
  38. sglang/srt/models/mixtral.py +6 -2
  39. sglang/srt/models/mixtral_quant.py +5 -1
  40. sglang/srt/models/qwen.py +5 -2
  41. sglang/srt/models/qwen2.py +6 -2
  42. sglang/srt/models/qwen2_moe.py +5 -14
  43. sglang/srt/models/stablelm.py +5 -1
  44. sglang/srt/openai_api/adapter.py +16 -1
  45. sglang/srt/openai_api/protocol.py +5 -5
  46. sglang/srt/sampling/sampling_batch_info.py +79 -6
  47. sglang/srt/server.py +6 -6
  48. sglang/srt/utils.py +0 -3
  49. sglang/test/runners.py +1 -1
  50. sglang/version.py +1 -1
  51. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/METADATA +7 -7
  52. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/RECORD +55 -52
  53. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  54. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  55. {sglang-0.2.14.post2.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,12 +22,15 @@ from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +47,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -169,6 +175,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
            batch_size=batch.batch_size(),
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
@@ -179,6 +186,8 @@ class InputMetadata:
            top_logprobs_nums=batch.top_logprobs_nums,
        )
 
+        ret.sampling_info.prepare_penalties()
+
        ret.compute_positions(batch)
 
        ret.compute_extend_infos(batch)
@@ -255,6 +264,42 @@ class InputMetadata:
        )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -278,17 +323,18 @@ def update_flashinfer_indices(
 
        kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
+
        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
        if forward_mode == ForwardMode.DECODE:
@@ -358,18 +404,17 @@ def update_flashinfer_indices(
 
            kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
            kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-            kv_indices = torch.cat(
-                [
-                    model_runner.req_to_token_pool.req_to_token[
-                        req_pool_indices_cpu[i],
-                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
-                    ]
-                    for i in range(batch_size)
-                ],
-                dim=0,
-            ).contiguous()
+
+            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+            create_flashinfer_kv_indices_triton[(batch_size,)](
+                model_runner.req_to_token_pool.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                model_runner.req_to_token_pool.req_to_token.size(1),
+                kv_indices,
+            )
 
            if forward_mode == ForwardMode.DECODE:
                # CUDA graph uses different flashinfer_decode_wrapper
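
For reference, the gather performed by `create_flashinfer_kv_indices_triton` can be written in plain PyTorch as follows. This is only a reading aid: it reproduces the semantics of the `torch.cat` loop that this release removes, and the helper name and toy tensors below are illustrative, not part of the package.

```python
import torch


def gather_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens, kv_start_idx=None):
    # For each request i, copy req_to_token[req_pool_indices[i], start : start + len_i]
    # into one flat index tensor -- the same result the Triton kernel writes into
    # kv_indices, computed here with the loop the new kernel replaces.
    chunks = []
    for i in range(len(req_pool_indices)):
        start = 0 if kv_start_idx is None else int(kv_start_idx[i])
        length = int(paged_kernel_lens[i])
        chunks.append(req_to_token[int(req_pool_indices[i]), start : start + length])
    return torch.cat(chunks, dim=0).contiguous()


# Toy usage on CPU tensors; the real code launches the Triton kernel on CUDA instead.
req_to_token = torch.arange(2 * 8, dtype=torch.int32).reshape(2, 8)
print(gather_kv_indices_reference(req_to_token, torch.tensor([1, 0]), torch.tensor([3, 5])))
```
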

sglang/srt/model_executor/model_runner.py CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -193,9 +195,9 @@ class ModelRunner:
            monkey_patch_vllm_qvk_linear_loader()
 
        self.dtype = self.vllm_model_config.dtype
-        if self.model_config.model_overide_args is not None:
+        if self.model_config.model_override_args is not None:
            self.vllm_model_config.hf_config.update(
-                self.model_config.model_overide_args
+                self.model_config.model_override_args
            )
 
        self.model = get_model(
@@ -346,13 +348,7 @@ class ModelRunner:
        if self.server_args.kv_cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
        else:
            raise ValueError(
                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -524,7 +520,11 @@ class ModelRunner:
 
    @torch.inference_mode()
    def forward_decode(self, batch: ScheduleBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and not batch.sampling_info.has_bias()
+        ):
            return self.cuda_graph_runner.replay(batch)
 
        input_metadata = InputMetadata.from_schedule_batch(
@@ -573,7 +573,9 @@ class ModelRunner:
            input_metadata.image_offsets,
        )
 
-    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:

sglang/srt/models/chatglm.py CHANGED
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
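
The chatglm.py change above is the template for the per-model edits that follow (commandr, dbrx, deepseek, and the other model files in the list): the vLLM `Sampler`/`SamplingMetadata` path is dropped, sampling moves into `sglang.srt.layers.sampler.Sampler`, and `forward` now returns the sampled tokens together with the logits-processor output, matching the new `ModelRunner.forward` annotation `Tuple[SampleOutput, LogitsProcessorOutput]`. A minimal sketch of that contract, using stand-in callables rather than the real SGLang classes:

```python
import torch.nn as nn


class TinyCausalLM(nn.Module):
    """Shape of the new forward() contract; the injected callables stand in for
    the real transformer, LogitsProcessor, and Sampler objects."""

    def __init__(self, transformer, lm_head, logits_processor, sampler):
        super().__init__()
        self.transformer = transformer
        self.lm_head = lm_head
        self.logits_processor = logits_processor
        self.sampler = sampler

    def forward(self, input_ids, positions, input_metadata):
        hidden_states = self.transformer(input_ids, positions, input_metadata)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        # New in 0.2.15: sample inside the model and return both results.
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output
```
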

sglang/srt/models/commandr.py CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
        self.config = config
        self.quant_config = quant_config
        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
        self.model = CohereModel(config, quant_config)
 
    @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
            positions,
            input_metadata,
        )
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [

sglang/srt/models/dbrx.py CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )
        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
    @torch.no_grad()
    def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        expert_params_mapping = [

sglang/srt/models/deepseek.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
    @torch.no_grad()
    def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [

sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,6 +19,7 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -45,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -160,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
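
The `input_to_float8` helper added above does per-tensor dynamic fp8 quantization: it scales the tensor so its absolute maximum maps to the dtype's representable maximum, saturates, casts, and returns the inverse scale for later dequantization (the values are then fed to flashinfer's `bmm_fp8` in the hunks below). A small self-contained round-trip check, assuming a PyTorch build that ships `torch.float8_e4m3fn`; the test values are arbitrary:

```python
import torch


def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Same logic as the helper added in this release: per-tensor dynamic scaling.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp8, inv_scale = input_to_float8(x)
# Dequantize and compare: the error should be small relative to |x|.
x_back = x_fp8.to(torch.float32) * inv_scale
print((x.float() - x_back).abs().max())
```
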
@@ -254,11 +265,6 @@ class DeepseekV2Attention(nn.Module):
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
        # TODO, support head_size 192
        self.attn = RadixAttention(
            self.num_local_heads,
@@ -282,7 +288,7 @@ class DeepseekV2Attention(nn.Module):
        q = self.q_proj(hidden_states)[0].view(
            -1, self.num_local_heads, self.qk_head_dim
        )
-        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
@@ -416,12 +422,9 @@ class DeepseekV2AttentionMLA(nn.Module):
            v_head_dim=self.kv_lora_rank,
        )
 
-        kv_b_proj = self.kv_b_proj
-        w_kc, w_vc = kv_b_proj.weight.unflatten(
-            0, (-1, qk_nope_head_dim + v_head_dim)
-        ).split([qk_nope_head_dim, v_head_dim], dim=1)
-        self.w_kc = w_kc
-        self.w_vc = w_vc
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None
 
    def forward(
        self,
@@ -442,8 +445,17 @@ class DeepseekV2AttentionMLA(nn.Module):
            -1, self.num_local_heads, self.qk_head_dim
        )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope_out = q_input[..., : self.kv_lora_rank]
-        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
 
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
@@ -458,16 +470,21 @@ class DeepseekV2AttentionMLA(nn.Module):
 
        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        attn_bmm_output = attn_output.new_empty(
-            q_len, self.num_local_heads, self.v_head_dim
-        )
-        torch.bmm(
-            attn_output.transpose(0, 1),
-            self.w_vc.transpose(1, 2).contiguous(),
-            out=attn_bmm_output.transpose(0, 1),
-        )
 
-        attn_output = attn_bmm_output.flatten(1, 2)
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)
 
        return output
@@ -632,6 +649,7 @@ class DeepseekV2ForCausalLM(nn.Module):
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
    def forward(
        self,
@@ -640,9 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
@@ -695,7 +715,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                    weight_loader(
                        param,
                        loaded_weight,
-                        weight_name,
+                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
@@ -711,5 +731,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                )
                weight_loader(param, loaded_weight)
 
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+
 
 EntryClass = DeepseekV2ForCausalLM
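
The `load_weights` tail above performs the MLA weight split that the constructor used to do eagerly: `kv_b_proj.weight`, of shape `(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)`, is split into per-head `w_kc` and `w_vc` so the attention path can run two batched matmuls against the compressed latent cache. A shape-only sketch with made-up dimensions (not DeepSeek-V2's real config):

```python
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# Dummy stand-in for kv_b_proj.weight: (num_heads * (nope + v), kv_lora_rank).
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = kv_b_proj_weight.unflatten(
    0, (-1, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)

w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)  # (heads, nope_dim, lora_rank)
w_vc = w_vc.contiguous().transpose(1, 2)                  # (heads, lora_rank, v_dim)

# In the forward pass these are consumed by batched matmuls:
q_len = 3
q_nope = torch.randn(q_len, num_heads, qk_nope_head_dim)
attn_output = torch.randn(q_len, num_heads, kv_lora_rank)

q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)            # (heads, q_len, lora_rank)
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), w_vc)  # (heads, q_len, v_dim)
print(q_nope_out.shape, attn_bmm_output.shape)
```
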