sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +26 -0
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -160
  11. sglang/srt/layers/activation.py +10 -4
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/layernorm.py +47 -1
  14. sglang/srt/layers/logits_processor.py +4 -4
  15. sglang/srt/layers/sampler.py +69 -16
  16. sglang/srt/managers/controller_multi.py +5 -5
  17. sglang/srt/managers/controller_single.py +5 -5
  18. sglang/srt/managers/io_struct.py +11 -5
  19. sglang/srt/managers/schedule_batch.py +25 -13
  20. sglang/srt/managers/tokenizer_manager.py +76 -63
  21. sglang/srt/managers/tp_worker.py +47 -36
  22. sglang/srt/model_config.py +3 -3
  23. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  24. sglang/srt/model_executor/forward_batch_info.py +78 -43
  25. sglang/srt/model_executor/model_runner.py +29 -18
  26. sglang/srt/models/chatglm.py +5 -13
  27. sglang/srt/models/commandr.py +5 -1
  28. sglang/srt/models/dbrx.py +5 -1
  29. sglang/srt/models/deepseek.py +5 -1
  30. sglang/srt/models/deepseek_v2.py +57 -25
  31. sglang/srt/models/exaone.py +399 -0
  32. sglang/srt/models/gemma.py +7 -3
  33. sglang/srt/models/gemma2.py +6 -52
  34. sglang/srt/models/gpt_bigcode.py +5 -1
  35. sglang/srt/models/grok.py +14 -4
  36. sglang/srt/models/internlm2.py +5 -1
  37. sglang/srt/models/llama2.py +10 -7
  38. sglang/srt/models/llama_classification.py +2 -6
  39. sglang/srt/models/llama_embedding.py +3 -4
  40. sglang/srt/models/llava.py +69 -91
  41. sglang/srt/models/llavavid.py +40 -86
  42. sglang/srt/models/minicpm.py +5 -1
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +9 -6
  47. sglang/srt/models/qwen2_moe.py +12 -33
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/models/yivl.py +2 -7
  50. sglang/srt/openai_api/adapter.py +16 -1
  51. sglang/srt/openai_api/protocol.py +5 -5
  52. sglang/srt/sampling/sampling_batch_info.py +79 -6
  53. sglang/srt/server.py +9 -9
  54. sglang/srt/utils.py +18 -36
  55. sglang/test/runners.py +2 -2
  56. sglang/test/test_layernorm.py +53 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
  59. sglang-0.2.15.dist-info/RECORD +118 -0
  60. sglang-0.2.14.post1.dist-info/RECORD +0 -114
  61. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -26,16 +26,18 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
-    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
+    LogitsProcessorOutput,
 )
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -144,6 +146,10 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
+        # Sampling inputs
+        vocab_size = model_runner.model_config.vocab_size
+        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
+
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
@@ -235,6 +241,7 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -299,27 +306,35 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
+        # Sampling inputs
+        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
+
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        output = self.output_buffers[bs]
+        sample_output, logits_output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-            output = LogitProcessorOutput(
-                next_token_logits=output.next_token_logits[:raw_bs],
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=logits_output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
+            sample_output = SampleOutput(
+                sample_output.success[:raw_bs],
+                sample_output.probs[:raw_bs],
+                sample_output.batch_next_token_ids[:raw_bs],
+            )
 
         # Extract logprobs
         if batch.return_logprob:
-            output.next_token_logprobs = torch.nn.functional.log_softmax(
-                output.next_token_logits, dim=-1
+            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -327,8 +342,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    output.next_token_logprobs, logits_metadata
+                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    logits_output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return output
+        return sample_output, logits_output
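
Note: the replay path above pads a live batch of size `raw_bs` up to a captured batch size `bs`, replays the fixed-shape CUDA graph, and then slices the outputs back to `raw_bs`. A minimal, self-contained sketch of that pad-then-slice pattern; `graph_fn` and `replay_padded` are illustrative stand-ins, not sglang APIs:

```python
# Illustrative sketch (not sglang code): a CUDA graph is captured for a fixed
# batch size, so a smaller live batch is padded before replay and the outputs
# are sliced back to the real batch size afterwards.
import torch

def replay_padded(graph_fn, inputs: torch.Tensor, captured_bs: int) -> torch.Tensor:
    raw_bs = inputs.shape[0]
    if raw_bs < captured_bs:
        pad = torch.zeros((captured_bs - raw_bs, *inputs.shape[1:]), dtype=inputs.dtype)
        inputs = torch.cat([inputs, pad], dim=0)   # pad up to the captured shape
    out = graph_fn(inputs)                          # stand-in for graph.replay()
    return out[:raw_bs]                             # unpad: keep only the real requests

if __name__ == "__main__":
    fn = lambda x: x * 2                            # stand-in for the captured computation
    print(replay_padded(fn, torch.ones(3, 4), captured_bs=8).shape)  # torch.Size([3, 4])
```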
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,16 +18,19 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -42,6 +47,7 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
+    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -58,6 +64,7 @@ class InputMetadata:
 
     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
@@ -69,8 +76,8 @@ class InputMetadata:
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None
 
     # Trition attention backend
     triton_max_seq_len: int = 0
@@ -87,20 +94,8 @@ class InputMetadata:
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = []
-        for r in reqs:
-            if isinstance(r.image_offset, list):
-                self.image_offsets.append(
-                    [
-                        (image_offset - len(r.prefix_indices))
-                        for image_offset in r.image_offset
-                    ]
-                )
-            elif isinstance(r.image_offset, int):
-                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-            elif r.image_offset is None:
-                self.image_offsets.append(0)
+        self.image_sizes = [r.image_sizes for r in reqs]
+        self.image_offsets = [r.image_offsets for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +148,7 @@ class InputMetadata:
             for i, r in enumerate(batch.reqs)
         ]
         self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -179,6 +175,7 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
+            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -189,6 +186,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.sampling_info.prepare_penalties()
+
        ret.compute_positions(batch)
 
        ret.compute_extend_infos(batch)
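
Note: `from_schedule_batch` now also calls `prepare_penalties()` on the batch's `SamplingBatchInfo` before the forward pass. The diff does not show that method's body; as a rough, generic illustration of what token-level penalties usually do to logits before sampling (OpenAI-style frequency/presence penalties, not necessarily sglang's exact formulation):

```python
# Generic illustration only: penalize a logits row based on tokens already
# generated for that request. sglang's prepare_penalties() may differ in details.
import torch

def apply_penalties(logits, output_token_ids, frequency_penalty=0.0, presence_penalty=0.0):
    # logits: [vocab_size]; output_token_ids: 1-D LongTensor of generated token ids
    vocab_size = logits.shape[-1]
    counts = torch.bincount(output_token_ids, minlength=vocab_size).to(logits.dtype)
    logits = logits - frequency_penalty * counts                         # scaled by occurrence count
    logits = logits - presence_penalty * (counts > 0).to(logits.dtype)   # flat penalty if seen at all
    return logits

logits = torch.zeros(8)
penalized = apply_penalties(
    logits, torch.tensor([3, 3, 5]), frequency_penalty=0.5, presence_penalty=0.2
)
print(penalized.tolist())  # token 3 gets -1.2, token 5 gets -0.7, others stay 0.0
```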
@@ -238,10 +237,10 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-        else:
+        if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
+        else:
+            prefix_lens = self.extend_prefix_lens
 
         update_flashinfer_indices(
             self.forward_mode,
@@ -265,6 +264,42 @@ class InputMetadata:
         )
 
 
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -288,17 +323,18 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
+
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
         if forward_mode == ForwardMode.DECODE:
@@ -368,18 +404,17 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i],
-                    kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            kv_start_idx,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
 
         if forward_mode == ForwardMode.DECODE:
             # CUDA graph uses different flashinfer_decode_wrapper
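
Note: both call sites above replace a Python-level `torch.cat` gather (shown in the removed lines) with a launch of the Triton kernel. The kernel's effect can be stated in plain PyTorch; the reference below mirrors the removed code and is only meant to clarify what `create_flashinfer_kv_indices_triton` writes into `kv_indices`, not to reproduce the kernel itself:

```python
# Reference sketch: for each request i, copy
# req_to_token[req_pool_indices[i], start : start + paged_kernel_lens[i]]
# into kv_indices at offset kv_indptr[i]. Tensor names mirror the diff.
import torch

def build_kv_indices_ref(req_to_token, req_pool_indices, paged_kernel_lens, kv_start_idx=None):
    batch_size = req_pool_indices.numel()
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int64)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=req_to_token.dtype)
    for i in range(batch_size):
        start = 0 if kv_start_idx is None else int(kv_start_idx[i])
        length = int(paged_kernel_lens[i])
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[
            req_pool_indices[i], start : start + length
        ]
    return kv_indptr, kv_indices

if __name__ == "__main__":
    req_to_token = torch.arange(2 * 8).reshape(2, 8)  # toy [max_batch, max_context_len]
    indptr, indices = build_kv_indices_ref(
        req_to_token,
        req_pool_indices=torch.tensor([1, 0]),
        paged_kernel_lens=torch.tensor([3, 2]),
    )
    print(indptr.tolist(), indices.tolist())  # [0, 3, 5] [8, 9, 10, 0, 1]
```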
sglang/srt/model_executor/model_runner.py CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Type
+from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
@@ -44,13 +44,15 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +71,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -85,7 +87,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +99,13 @@ class ModelRunner:
             }
         )
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -184,9 +195,9 @@ class ModelRunner:
             monkey_patch_vllm_qvk_linear_loader()
 
         self.dtype = self.vllm_model_config.dtype
-        if self.model_config.model_overide_args is not None:
+        if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
-                self.model_config.model_overide_args
+                self.model_config.model_override_args
             )
 
         self.model = get_model(
@@ -337,13 +348,7 @@ class ModelRunner:
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
-            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
-                logger.warning(
-                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
-                )
-                self.kv_cache_dtype = self.dtype
-            else:
-                self.kv_cache_dtype = torch.float8_e5m2
+            self.kv_cache_dtype = torch.float8_e5m2
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -507,15 +512,19 @@ class ModelRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable torch compile by not using --enable-torch-compile\n"
-                "2. disable cuda graph by --disable-cuda-graph\n"
-                "3. set --mem-fraction-static to a smaller value\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+        if (
+            self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(len(batch.reqs))
+            and not batch.sampling_info.has_bias()
+        ):
            return self.cuda_graph_runner.replay(batch)
 
        input_metadata = InputMetadata.from_schedule_batch(
@@ -564,7 +573,9 @@ class ModelRunner:
            input_metadata.image_offsets,
        )
 
-    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+    def forward(
+        self, batch: ScheduleBatch, forward_mode: ForwardMode
+    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
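
Note: `ModelRunner.forward` now returns a `(SampleOutput, LogitsProcessorOutput)` pair instead of bare logits. A minimal sketch of the new convention from the caller's side; `forward_stub` and the dataclasses here are simplified stand-ins that only mirror the fields referenced in this diff, not sglang's own definitions:

```python
# Sketch of consuming the new return type of ModelRunner.forward.
from dataclasses import dataclass
from typing import Optional, Tuple
import torch

@dataclass
class SampleOutput:
    success: torch.Tensor               # per-request sampling success flags
    probs: torch.Tensor                 # probabilities actually sampled from
    batch_next_token_ids: torch.Tensor  # sampled token ids

@dataclass
class LogitsProcessorOutput:
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None

def forward_stub(batch_logits: torch.Tensor) -> Tuple[SampleOutput, LogitsProcessorOutput]:
    probs = torch.softmax(batch_logits, dim=-1)
    ids = probs.argmax(dim=-1)
    ok = torch.ones_like(ids, dtype=torch.bool)
    return SampleOutput(ok, probs, ids), LogitsProcessorOutput(batch_logits)

# Caller side: unpack the pair; logprobs are derived from the logits on demand,
# as the CUDA graph replay path above does with log_softmax.
sample_output, logits_output = forward_stub(torch.randn(4, 32))
logits_output.next_token_logprobs = torch.log_softmax(logits_output.next_token_logits, dim=-1)
print(sample_output.batch_next_token_ids.tolist())
```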
sglang/srt/models/chatglm.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
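
Note: the same mechanical change is applied to the other model files in this release (commandr, dbrx, deepseek below, and the rest listed above): the vLLM `Sampler`/`SamplingMetadata` path is dropped, the model holds an sglang `Sampler` next to its `LogitsProcessor`, and `forward` returns `(sample_output, logits_output)`. A condensed, hypothetical sketch of that pattern; the `Dummy*` classes are placeholders, not the sglang implementations:

```python
# Placeholder components illustrating the new per-model forward contract.
import torch
import torch.nn as nn

class DummyLogitsProcessor(nn.Module):
    def forward(self, hidden_states, lm_head_weight):
        return hidden_states @ lm_head_weight.t()        # [batch, vocab]

class DummySampler(nn.Module):
    def forward(self, logits, sampling_info=None):
        return torch.argmax(logits, dim=-1)              # greedy stand-in for real sampling

class DummyCausalLM(nn.Module):
    def __init__(self, hidden=16, vocab=32):
        super().__init__()
        self.lm_head = nn.Linear(hidden, vocab, bias=False)
        self.logits_processor = DummyLogitsProcessor()
        self.sampler = DummySampler()

    def forward(self, hidden_states, sampling_info=None):
        logits_output = self.logits_processor(hidden_states, self.lm_head.weight)
        sample_output = self.sampler(logits_output, sampling_info)
        return sample_output, logits_output              # new tuple convention

sample_output, logits_output = DummyCausalLM()(torch.randn(2, 16))
print(sample_output.shape, logits_output.shape)  # torch.Size([2]) torch.Size([2, 32])
```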
sglang/srt/models/commandr.py CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [