sglang 0.2.14.post2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/backend/runtime_endpoint.py +8 -4
  4. sglang/lang/interpreter.py +3 -0
  5. sglang/lang/ir.py +5 -0
  6. sglang/launch_server_llavavid.py +12 -12
  7. sglang/srt/configs/__init__.py +5 -0
  8. sglang/srt/configs/exaone.py +195 -0
  9. sglang/srt/constrained/fsm_cache.py +1 -1
  10. sglang/srt/conversation.py +24 -2
  11. sglang/srt/hf_transformers_utils.py +12 -12
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/sampler.py +94 -17
  15. sglang/srt/managers/controller_multi.py +5 -5
  16. sglang/srt/managers/controller_single.py +5 -5
  17. sglang/srt/managers/io_struct.py +6 -1
  18. sglang/srt/managers/schedule_batch.py +26 -11
  19. sglang/srt/managers/tokenizer_manager.py +9 -9
  20. sglang/srt/managers/tp_worker.py +38 -26
  21. sglang/srt/model_config.py +3 -3
  22. sglang/srt/model_executor/cuda_graph_runner.py +26 -9
  23. sglang/srt/model_executor/forward_batch_info.py +68 -23
  24. sglang/srt/model_executor/model_runner.py +15 -22
  25. sglang/srt/models/chatglm.py +9 -15
  26. sglang/srt/models/commandr.py +5 -1
  27. sglang/srt/models/dbrx.py +5 -1
  28. sglang/srt/models/deepseek.py +5 -1
  29. sglang/srt/models/deepseek_v2.py +57 -25
  30. sglang/srt/models/exaone.py +368 -0
  31. sglang/srt/models/gemma.py +5 -1
  32. sglang/srt/models/gemma2.py +5 -1
  33. sglang/srt/models/gpt_bigcode.py +5 -1
  34. sglang/srt/models/grok.py +5 -1
  35. sglang/srt/models/internlm2.py +5 -1
  36. sglang/srt/models/{llama2.py → llama.py} +25 -45
  37. sglang/srt/models/llama_classification.py +34 -41
  38. sglang/srt/models/llama_embedding.py +7 -6
  39. sglang/srt/models/llava.py +8 -11
  40. sglang/srt/models/llavavid.py +5 -6
  41. sglang/srt/models/minicpm.py +5 -1
  42. sglang/srt/models/mistral.py +2 -3
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +6 -2
  47. sglang/srt/models/qwen2_moe.py +5 -14
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/openai_api/adapter.py +16 -1
  50. sglang/srt/openai_api/protocol.py +5 -5
  51. sglang/srt/sampling/sampling_batch_info.py +75 -6
  52. sglang/srt/server.py +6 -6
  53. sglang/srt/utils.py +0 -3
  54. sglang/test/runners.py +1 -1
  55. sglang/test/test_programs.py +68 -0
  56. sglang/test/test_utils.py +4 -0
  57. sglang/utils.py +39 -0
  58. sglang/version.py +1 -1
  59. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/METADATA +9 -8
  60. sglang-0.3.0.dist-info/RECORD +118 -0
  61. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/WHEEL +1 -1
  62. sglang-0.2.14.post2.dist-info/RECORD +0 -115
  63. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/LICENSE +0 -0
  64. {sglang-0.2.14.post2.dist-info → sglang-0.3.0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -44,8 +46,10 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
         if isinstance(sub, CustomOp):
             if reverse:
                 sub._forward_method = sub.forward_cuda
+                setattr(sub, "is_torch_compile", False)
             else:
                 sub._forward_method = sub.forward_native
+                setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
 
@@ -144,6 +148,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
@@ -235,6 +243,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -299,27 +308,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        output = self.output_buffers[bs]
+        sample_output, logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-            output = LogitProcessorOutput(
-                next_token_logits=output.next_token_logits[:raw_bs],
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )
 
         # Extract logprobs
        if batch.return_logprob:
-            output.next_token_logprobs = torch.nn.functional.log_softmax(
-                output.next_token_logits, dim=-1
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -327,8 +344,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    output.next_token_logprobs, logits_metadata
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return output
+        return sample_output, logits_output
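
The replay path above now returns a (SampleOutput, LogitsProcessorOutput) pair and feeds sampling parameters to the captured graph through a persistent SamplingBatchInfo buffer. A minimal sketch of the underlying idea, not the real SamplingBatchInfo API beyond the dummy_one/inplace_assign names visible in the hunks: CUDA graphs replay fixed memory addresses, so per-batch sampling state must live in tensors allocated once at max_bs and then overwritten in place and sliced, never re-allocated.

import torch

# Illustrative buffer, standing in for one field of SamplingBatchInfo.
max_bs = 32
temperatures = torch.ones(max_bs, 1)  # allocated once; this address is what the graph captures

def inplace_assign(raw_bs: int, new_temperatures: torch.Tensor) -> None:
    # Overwrite only the live rows; the storage pointer recorded in the
    # captured graph never changes.
    temperatures[:raw_bs].copy_(new_temperatures)

inplace_assign(3, torch.tensor([[0.7], [1.0], [0.2]]))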
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,12 +22,15 @@ from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +47,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -169,6 +175,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -179,6 +186,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.sampling_info.prepare_penalties()
+
         ret.compute_positions(batch)
 
         ret.compute_extend_infos(batch)
@@ -255,6 +264,42 @@ class InputMetadata:
             )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -278,17 +323,18 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
+
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
         if forward_mode == ForwardMode.DECODE:
@@ -358,18 +404,17 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i],
-                    kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
 
         if forward_mode == ForwardMode.DECODE:
             # CUDA graph uses different flashinfer_decode_wrapper
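
The removed torch.cat loop is the reference behaviour of the new Triton kernel: for each request i, copy req_to_token[req_pool_indices[i], start : start + len_i] into kv_indices at offset kv_indptr[i]. A small CPU sketch of that flattening (all sizes and values here are made up for illustration):

import torch

req_to_token = torch.arange(4 * 8, dtype=torch.int32).reshape(4, 8)  # [max_batch, max_context_len]
req_pool_indices = torch.tensor([2, 0, 3])
paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

kv_indptr = torch.zeros(len(req_pool_indices) + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(len(req_pool_indices)):
    # Same slice-and-copy the kernel performs in blocks of 512 per program id.
    row = req_to_token[req_pool_indices[i], : paged_kernel_lens[i]]
    kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = row
print(kv_indices)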
sglang/srt/model_executor/model_runner.py CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -44,6 +44,8 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -160,6 +162,7 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
+        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -193,9 +196,9 @@ class ModelRunner:
            monkey_patch_vllm_qvk_linear_loader()
 
         self.dtype = self.vllm_model_config.dtype
-        if self.model_config.model_overide_args is not None:
+        if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
-                self.model_config.model_overide_args
+                self.model_config.model_override_args
             )
 
         self.model = get_model(
@@ -346,13 +349,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -524,7 +521,11 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and batch.sampling_info.can_run_in_cuda_graph()
+        ):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -573,7 +574,9 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
@@ -604,16 +607,6 @@ def import_model_classes():
                     assert entry.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[entry.__name__] = entry
 
-            # compat: some models such as chatglm has incorrect class set in config.json
-            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-            if hasattr(module, "EntryClassRemapping") and isinstance(
-                module.EntryClassRemapping, list
-            ):
-                for remap in module.EntryClassRemapping:
-                    if isinstance(remap, tuple) and len(remap) == 2:
-                        assert remap[0] not in model_arch_name_to_cls
-                        model_arch_name_to_cls[remap[0]] = remap[1]
-
     return model_arch_name_to_cls
 
 
sglang/srt/models/chatglm.py CHANGED
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
@@ -410,6 +402,8 @@ class ChatGLMForCausalLM(nn.Module):
             weight_loader(param, loaded_weight)
 
 
-EntryClass = ChatGLMForCausalLM
-# compat: glm model.config class == ChatGLMModel
-EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
+class ChatGLMModel(ChatGLMForCausalLM):
+    pass
+
+
+EntryClass = [ChatGLMForCausalLM, ChatGLMModel]
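
The same mechanical change is applied to every model file below (commandr, dbrx, deepseek, deepseek_v2, and the rest listed at the top): the vLLM Sampler import and the separate sample() method are removed, a sglang Sampler is instantiated next to the LogitsProcessor, and forward() returns (sample_output, logits_output), matching the new Tuple[SampleOutput, LogitsProcessorOutput] signature of ModelRunner.forward. A skeleton of the shared pattern; the backbone, lm_head, and constructor arguments are omitted, so this is illustrative rather than runnable as-is:

import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata

class AnyModelForCausalLM(nn.Module):
    def __init__(self, config, quant_config=None):
        super().__init__()
        # self.model / self.transformer and self.lm_head omitted for brevity.
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()

    @torch.no_grad()
    def forward(self, input_ids, positions, input_metadata: InputMetadata):
        hidden_states = self.model(input_ids, positions, input_metadata)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output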
sglang/srt/models/commandr.py CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,6 +19,7 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -45,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -160,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
+def input_to_float8(x, dtype=torch.float8_e4m3fn):
+    finfo = torch.finfo(dtype)
+    min_val, max_val = x.aminmax()
+    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+    scale = finfo.max / amax
+    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
@@ -254,11 +265,6 @@ class DeepseekV2Attention(nn.Module):
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
 
-        # self.attn = Attention(self.num_heads,
-        #                       self.qk_head_dim,
-        #                       self.scaling,
-        #                       num_kv_heads=self.num_heads)
-
         # TODO, support head_size 192
         self.attn = RadixAttention(
             self.num_local_heads,
@@ -282,7 +288,7 @@ class DeepseekV2Attention(nn.Module):
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
-        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
@@ -416,12 +422,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             v_head_dim=self.kv_lora_rank,
         )
 
-        kv_b_proj = self.kv_b_proj
-        w_kc, w_vc = kv_b_proj.weight.unflatten(
-            0, (-1, qk_nope_head_dim + v_head_dim)
-        ).split([qk_nope_head_dim, v_head_dim], dim=1)
-        self.w_kc = w_kc
-        self.w_vc = w_vc
+        self.w_kc = None
+        self.w_vc = None
+        self.w_scale = None
 
     def forward(
         self,
@@ -442,8 +445,17 @@ class DeepseekV2AttentionMLA(nn.Module):
             -1, self.num_local_heads, self.qk_head_dim
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_nope_out = q_input[..., : self.kv_lora_rank]
-        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        if self.w_kc.dtype == torch.float8_e4m3fn:
+            q_nope_val, q_nope_scale = input_to_float8(
+                q_nope.transpose(0, 1), torch.float8_e4m3fn
+            )
+            q_nope_out = bmm_fp8(
+                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+            )
+        else:
+            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
 
         latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
@@ -458,16 +470,21 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         attn_output = self.attn(q_input, k_input, v_input, input_metadata)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
-        attn_bmm_output = attn_output.new_empty(
-            q_len, self.num_local_heads, self.v_head_dim
-        )
-        torch.bmm(
-            attn_output.transpose(0, 1),
-            self.w_vc.transpose(1, 2).contiguous(),
-            out=attn_bmm_output.transpose(0, 1),
-        )
 
-        attn_output = attn_bmm_output.flatten(1, 2)
+        if self.w_vc.dtype == torch.float8_e4m3fn:
+            attn_output_val, attn_output_scale = input_to_float8(
+                attn_output.transpose(0, 1), torch.float8_e4m3fn
+            )
+            attn_bmm_output = bmm_fp8(
+                attn_output_val,
+                self.w_vc,
+                attn_output_scale,
+                self.w_scale,
+                torch.bfloat16,
+            )
+        else:
+            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
+        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
         output, _ = self.o_proj(attn_output)
 
         return output
@@ -632,6 +649,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -640,9 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -695,7 +715,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     weight_loader(
                         param,
                         loaded_weight,
-                        weight_name,
+                        name,
                         shard_id=shard_id,
                         expert_id=expert_id,
                     )
@@ -711,5 +731,17 @@ class DeepseekV2ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
+        if global_server_args_dict["enable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                del self_attn.kv_b_proj
+
 
 EntryClass = DeepseekV2ForCausalLM
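
input_to_float8 above performs per-tensor dynamic quantization: one scale is derived from the absolute maximum, the tensor is saturated into the fp8 range, and the reciprocal scale is returned for dequantization (consumed here by bmm_fp8 together with the pre-extracted w_kc/w_vc weights). A self-contained round-trip check of just the quantization step; it needs a PyTorch build with float8_e4m3fn support and uses plain torch rather than flashinfer:

import torch

def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Same helper as added in the diff above.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

x = torch.randn(4, 8)
x_fp8, inv_scale = input_to_float8(x)
x_back = x_fp8.to(torch.float32) * inv_scale   # dequantize with the returned reciprocal scale
print((x - x_back).abs().max())                # small quantization error expected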