sglang 0.2.14__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (42)
  1. sglang/srt/constrained/fsm_cache.py +11 -2
  2. sglang/srt/constrained/jump_forward.py +1 -0
  3. sglang/srt/layers/activation.py +83 -7
  4. sglang/srt/layers/layernorm.py +0 -3
  5. sglang/srt/layers/logits_processor.py +4 -4
  6. sglang/srt/layers/sampler.py +15 -68
  7. sglang/srt/managers/schedule_batch.py +15 -20
  8. sglang/srt/managers/tp_worker.py +40 -33
  9. sglang/srt/model_executor/cuda_graph_runner.py +17 -31
  10. sglang/srt/model_executor/forward_batch_info.py +1 -8
  11. sglang/srt/model_executor/model_runner.py +5 -11
  12. sglang/srt/models/chatglm.py +12 -4
  13. sglang/srt/models/commandr.py +1 -5
  14. sglang/srt/models/dbrx.py +1 -5
  15. sglang/srt/models/deepseek.py +1 -5
  16. sglang/srt/models/deepseek_v2.py +1 -5
  17. sglang/srt/models/gemma.py +1 -5
  18. sglang/srt/models/gemma2.py +1 -5
  19. sglang/srt/models/gpt_bigcode.py +2 -6
  20. sglang/srt/models/grok.py +1 -5
  21. sglang/srt/models/internlm2.py +1 -5
  22. sglang/srt/models/llama2.py +3 -7
  23. sglang/srt/models/llama_classification.py +2 -2
  24. sglang/srt/models/minicpm.py +1 -5
  25. sglang/srt/models/mixtral.py +1 -5
  26. sglang/srt/models/mixtral_quant.py +1 -5
  27. sglang/srt/models/qwen.py +2 -5
  28. sglang/srt/models/qwen2.py +2 -6
  29. sglang/srt/models/qwen2_moe.py +14 -5
  30. sglang/srt/models/stablelm.py +1 -5
  31. sglang/srt/openai_api/adapter.py +85 -4
  32. sglang/srt/openai_api/protocol.py +2 -0
  33. sglang/srt/sampling/sampling_batch_info.py +1 -74
  34. sglang/srt/sampling/sampling_params.py +4 -0
  35. sglang/srt/server.py +8 -1
  36. sglang/test/runners.py +1 -1
  37. sglang/version.py +1 -1
  38. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +10 -4
  39. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/RECORD +42 -42
  40. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  41. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  42. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
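Taken together, the per-file diffs below roll back the in-model sampling path: each model's forward() goes back to returning the logits-processor output alone, and the SGLang-side Sampler, SampleOutput, and SamplingBatchInfo plumbing is removed from the model and CUDA-graph code. As a minimal sketch of the resulting call shape (decode_step and the greedy argmax are illustrative stand-ins, not code from this package):

    import torch

    def decode_step(model, input_ids, positions, input_metadata):
        # forward() now returns a single LogitProcessorOutput rather than a
        # (sample_output, logits_output) tuple.
        logits_output = model.forward(input_ids, positions, input_metadata)
        # Sampling happens outside the model; greedy argmax shown purely for
        # illustration.
        return torch.argmax(logits_output.next_token_logits, dim=-1)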
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -17,6 +17,7 @@ limitations under the License.
 
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -25,18 +26,16 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
-    LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -53,12 +52,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None
 
     try:
-        if use_compile:
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +66,7 @@ def patch_model(
     else:
         yield model.forward
     finally:
-        if use_compile:
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -88,7 +87,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
@@ -145,22 +144,18 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
            set_torch_compile_config()
 
-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
        if self.disable_padding:
            return batch_size in self.graphs
        else:
            return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
        self.batch_size_list = batch_size_list
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
@@ -181,7 +176,7 @@ class CudaGraphRunner:
            self.output_buffers[bs] = output_buffers
            self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
 
@@ -240,7 +235,6 @@ class CudaGraphRunner:
        def run_once():
            input_metadata = InputMetadata(
                forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                batch_size=bs,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens,
@@ -305,35 +299,27 @@ class CudaGraphRunner:
            self.flashinfer_handlers[bs],
        )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
        # Replay
        torch.cuda.synchronize()
        self.graphs[bs].replay()
        torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        output = self.output_buffers[bs]
 
        # Unpad
        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
                next_token_logprobs=None,
                normalized_prompt_logprobs=None,
                input_token_logprobs=None,
                input_top_logprobs=None,
                output_top_logprobs=None,
            )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
        # Extract logprobs
        if batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
            )
            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
            if return_top_logprob:
@@ -341,8 +327,8 @@ class CudaGraphRunner:
                    forward_mode=ForwardMode.DECODE,
                    top_logprobs_nums=batch.top_logprobs_nums,
                )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
                )[1]
 
-        return sample_output, logits_output
+        return output
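With the SamplingBatchInfo plumbing gone, replay() returns the single (unpadded) LogitProcessorOutput instead of a (sample_output, logits_output) pair, and can_run() gates on batch size alone. A hedged consumer-side sketch, with runner and batch standing in for the real objects:

    import torch

    def decode_with_cuda_graph(runner, batch):
        assert runner.can_run(len(batch.reqs))
        output = runner.replay(batch)  # one LogitProcessorOutput
        # next_token_logprobs is filled in by replay() only when
        # batch.return_logprob is set; the logits are always present.
        return torch.argmax(output.next_token_logits, dim=-1)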
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
@@ -28,7 +26,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -45,7 +42,6 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -183,7 +179,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -194,8 +189,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
-        ret.sampling_info.prepare_penalties()
-
         ret.compute_positions(batch)
 
         ret.compute_extend_infos(batch)
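InputMetadata now carries only forward-pass state: the sampling_info field and the prepare_penalties() call disappear from from_schedule_batch(). A partial sketch of the slimmed dataclass, listing only the fields visible in the hunks above (the real class has more):

    from dataclasses import dataclass
    import torch

    @dataclass
    class InputMetadataSketch:
        forward_mode: "ForwardMode"   # DECODE / EXTEND
        batch_size: int
        req_pool_indices: torch.Tensor
        seq_lens: torch.Tensor
        # ... remaining fields unchanged by this diff; sampling_info removed.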
sglang/srt/model_executor/model_runner.py CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
@@ -44,8 +44,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -161,6 +159,8 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")
 
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
@@ -515,11 +515,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -568,9 +564,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
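Because sampling no longer happens inside the captured graph, forward_decode() can take the CUDA-graph path whenever the batch size fits, without the logit-bias check, and forward() loses its Tuple[SampleOutput, LogitsProcessorOutput] annotation. A usage sketch (runner and batch are placeholders; the ForwardMode import path is the one shown in the diff):

    from sglang.srt.model_executor.forward_batch_info import ForwardMode

    def run_decode(runner, batch):
        # forward() now returns the logits-processor output directly.
        logits_output = runner.forward(batch, ForwardMode.DECODE)
        return logits_output.next_token_logits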
sglang/srt/models/chatglm.py CHANGED
@@ -31,18 +31,20 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -381,11 +383,17 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
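chatglm.py is the one model in this diff that switches to vLLM's sampler rather than simply dropping sampling: forward() returns the logits and a separate sample() hook applies vllm's Sampler. A sketch of that two-step contract; the generate_step wrapper is illustrative, not code from this package:

    def generate_step(model, input_ids, positions, input_metadata, sampling_metadata):
        # Step 1: forward() yields logits only.
        logits = model.forward(input_ids, positions, input_metadata)
        # Step 2: sample() wraps vllm's Sampler; its Optional[SamplerOutput]
        # annotation allows a None result.
        return model.sample(logits, sampling_metadata)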
sglang/srt/models/commandr.py CHANGED
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py CHANGED
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py CHANGED
@@ -45,7 +45,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -633,7 +632,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     def forward(
         self,
@@ -642,11 +640,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma2.py CHANGED
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -397,7 +396,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -408,11 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -33,9 +32,9 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -298,7 +297,6 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py CHANGED
@@ -39,9 +39,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -303,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -312,13 +310,11 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> LogitsProcessorOutput:
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def get_module_name(self, name):
         stacked_params_mapping = [
sglang/srt/models/llama_classification.py CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)
 
-        return LogitsProcessorOutput(
+        return LogitProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
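Both CudaGraphRunner.replay() and LlamaForClassification construct LogitProcessorOutput with the same six fields, so its shape can be read off this diff. An inferred sketch (the real class lives in sglang/srt/layers/logits_processor.py and may differ in detail):

    from dataclasses import dataclass
    from typing import List, Optional
    import torch

    @dataclass
    class LogitProcessorOutputSketch:
        next_token_logits: torch.Tensor
        next_token_logprobs: Optional[torch.Tensor]
        normalized_prompt_logprobs: Optional[torch.Tensor]
        input_token_logprobs: Optional[torch.Tensor]
        input_top_logprobs: Optional[List]
        output_top_logprobs: Optional[List]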
sglang/srt/models/minicpm.py CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [