sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sglang/launch_server_llavavid.py +26 -0
  2. sglang/srt/constrained/fsm_cache.py +11 -2
  3. sglang/srt/constrained/jump_forward.py +1 -0
  4. sglang/srt/hf_transformers_utils.py +0 -149
  5. sglang/srt/layers/activation.py +93 -11
  6. sglang/srt/layers/layernorm.py +47 -4
  7. sglang/srt/layers/logits_processor.py +4 -4
  8. sglang/srt/layers/sampler.py +15 -68
  9. sglang/srt/managers/io_struct.py +5 -4
  10. sglang/srt/managers/schedule_batch.py +20 -25
  11. sglang/srt/managers/tokenizer_manager.py +74 -61
  12. sglang/srt/managers/tp_worker.py +49 -43
  13. sglang/srt/model_executor/cuda_graph_runner.py +17 -31
  14. sglang/srt/model_executor/forward_batch_info.py +9 -26
  15. sglang/srt/model_executor/model_runner.py +20 -17
  16. sglang/srt/models/chatglm.py +13 -5
  17. sglang/srt/models/commandr.py +1 -5
  18. sglang/srt/models/dbrx.py +1 -5
  19. sglang/srt/models/deepseek.py +1 -5
  20. sglang/srt/models/deepseek_v2.py +1 -5
  21. sglang/srt/models/gemma.py +3 -7
  22. sglang/srt/models/gemma2.py +2 -56
  23. sglang/srt/models/gpt_bigcode.py +2 -6
  24. sglang/srt/models/grok.py +10 -8
  25. sglang/srt/models/internlm2.py +1 -5
  26. sglang/srt/models/llama2.py +6 -11
  27. sglang/srt/models/llama_classification.py +2 -6
  28. sglang/srt/models/llama_embedding.py +3 -4
  29. sglang/srt/models/llava.py +69 -91
  30. sglang/srt/models/llavavid.py +40 -86
  31. sglang/srt/models/minicpm.py +1 -5
  32. sglang/srt/models/mixtral.py +1 -5
  33. sglang/srt/models/mixtral_quant.py +1 -5
  34. sglang/srt/models/qwen.py +2 -5
  35. sglang/srt/models/qwen2.py +5 -10
  36. sglang/srt/models/qwen2_moe.py +21 -24
  37. sglang/srt/models/stablelm.py +1 -5
  38. sglang/srt/models/yivl.py +2 -7
  39. sglang/srt/openai_api/adapter.py +85 -4
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +1 -74
  42. sglang/srt/sampling/sampling_params.py +4 -0
  43. sglang/srt/server.py +11 -4
  44. sglang/srt/utils.py +18 -33
  45. sglang/test/runners.py +2 -2
  46. sglang/test/test_layernorm.py +53 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
  49. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
  50. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
  51. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -17,6 +17,7 @@ limitations under the License.

  import bisect
  from contextlib import contextmanager
+ from typing import Callable, List

  import torch
  from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -25,18 +26,16 @@ from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp

  from sglang.srt.layers.logits_processor import (
+     LogitProcessorOutput,
      LogitsMetadata,
      LogitsProcessor,
-     LogitsProcessorOutput,
  )
- from sglang.srt.layers.sampler import SampleOutput
  from sglang.srt.managers.schedule_batch import ScheduleBatch
  from sglang.srt.model_executor.forward_batch_info import (
      ForwardMode,
      InputMetadata,
      update_flashinfer_indices,
  )
- from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.utils import monkey_patch_vllm_all_gather


@@ -53,12 +52,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):

  @contextmanager
  def patch_model(
-     model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+     model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
  ):
      backup_ca_comm = None

      try:
-         if use_compile:
+         if enable_compile:
              _to_torch(model)
              monkey_patch_vllm_all_gather()
              backup_ca_comm = tp_group.ca_comm
@@ -67,7 +66,7 @@ def patch_model(
          else:
              yield model.forward
      finally:
-         if use_compile:
+         if enable_compile:
              _to_torch(model, reverse=True)
              monkey_patch_vllm_all_gather(reverse=True)
              tp_group.ca_comm = backup_ca_comm
@@ -88,7 +87,7 @@ def set_torch_compile_config():
  class CudaGraphRunner:
      def __init__(
          self,
-         model_runner,
+         model_runner: "ModelRunner",
          max_batch_size_to_capture: int,
          use_torch_compile: bool,
          disable_padding: bool,
@@ -145,22 +144,18 @@ class CudaGraphRunner:
              self.flashinfer_kv_indices.clone(),
          ]

-         # Sampling inputs
-         vocab_size = model_runner.model_config.vocab_size
-         self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
          self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

          if use_torch_compile:
              set_torch_compile_config()

-     def can_run(self, batch_size):
+     def can_run(self, batch_size: int):
          if self.disable_padding:
              return batch_size in self.graphs
          else:
              return batch_size <= self.max_bs

-     def capture(self, batch_size_list):
+     def capture(self, batch_size_list: List[int]):
          self.batch_size_list = batch_size_list
          with graph_capture() as graph_capture_context:
              self.stream = graph_capture_context.stream
@@ -181,7 +176,7 @@ class CudaGraphRunner:
              self.output_buffers[bs] = output_buffers
              self.flashinfer_handlers[bs] = flashinfer_handler

-     def capture_one_batch_size(self, bs, forward):
+     def capture_one_batch_size(self, bs: int, forward: Callable):
          graph = torch.cuda.CUDAGraph()
          stream = self.stream

@@ -240,7 +235,6 @@ class CudaGraphRunner:
          def run_once():
              input_metadata = InputMetadata(
                  forward_mode=ForwardMode.DECODE,
-                 sampling_info=self.sampling_info[:bs],
                  batch_size=bs,
                  req_pool_indices=req_pool_indices,
                  seq_lens=seq_lens,
@@ -305,35 +299,27 @@ class CudaGraphRunner:
              self.flashinfer_handlers[bs],
          )

-         # Sampling inputs
-         self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
          # Replay
          torch.cuda.synchronize()
          self.graphs[bs].replay()
          torch.cuda.synchronize()
-         sample_output, logits_output = self.output_buffers[bs]
+         output = self.output_buffers[bs]

          # Unpad
          if bs != raw_bs:
-             logits_output = LogitsProcessorOutput(
-                 next_token_logits=logits_output.next_token_logits[:raw_bs],
+             output = LogitProcessorOutput(
+                 next_token_logits=output.next_token_logits[:raw_bs],
                  next_token_logprobs=None,
                  normalized_prompt_logprobs=None,
                  input_token_logprobs=None,
                  input_top_logprobs=None,
                  output_top_logprobs=None,
              )
-             sample_output = SampleOutput(
-                 sample_output.success[:raw_bs],
-                 sample_output.probs[:raw_bs],
-                 sample_output.batch_next_token_ids[:raw_bs],
-             )

          # Extract logprobs
          if batch.return_logprob:
-             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                 logits_output.next_token_logits, dim=-1
+             output.next_token_logprobs = torch.nn.functional.log_softmax(
+                 output.next_token_logits, dim=-1
              )
              return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
              if return_top_logprob:
@@ -341,8 +327,8 @@ class CudaGraphRunner:
                      forward_mode=ForwardMode.DECODE,
                      top_logprobs_nums=batch.top_logprobs_nums,
                  )
-                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                     logits_output.next_token_logprobs, logits_metadata
+                 output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                     output.next_token_logprobs, logits_metadata
                  )[1]

-         return sample_output, logits_output
+         return output
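
The net effect of the hunks above is that the CUDA graph replay path now hands back a single LogitProcessorOutput instead of a (SampleOutput, LogitsProcessorOutput) pair; sampling happens outside the captured graph. The snippet below is a self-contained sketch of the unpad-and-logprob step visible above, using plain torch and stand-in names (LogitProcessorOutputSketch and unpad_and_extract_logprobs are illustrative, not sglang APIs).

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class LogitProcessorOutputSketch:
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None


def unpad_and_extract_logprobs(
    output: LogitProcessorOutputSketch, raw_bs: int, return_logprob: bool
) -> LogitProcessorOutputSketch:
    # Keep only the rows that belong to real requests; the rest exist because
    # the batch was padded up to a captured graph size.
    output = LogitProcessorOutputSketch(
        next_token_logits=output.next_token_logits[:raw_bs]
    )
    if return_logprob:
        output.next_token_logprobs = torch.nn.functional.log_softmax(
            output.next_token_logits, dim=-1
        )
    return output


padded = LogitProcessorOutputSketch(next_token_logits=torch.randn(8, 128))
out = unpad_and_extract_logprobs(padded, raw_bs=3, return_logprob=True)
print(out.next_token_logits.shape, out.next_token_logprobs.shape)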

sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -1,5 +1,3 @@
- from __future__ import annotations
-
  """
  Copyright 2023-2024 SGLang Team
  Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,7 +26,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

  if TYPE_CHECKING:
      from sglang.srt.model_executor.model_runner import ModelRunner
-     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


  class ForwardMode(IntEnum):
@@ -45,7 +42,6 @@ class InputMetadata:
      """Store all inforamtion of a forward pass."""

      forward_mode: ForwardMode
-     sampling_info: SamplingBatchInfo
      batch_size: int
      req_pool_indices: torch.Tensor
      seq_lens: torch.Tensor
@@ -62,6 +58,7 @@ class InputMetadata:

      # For extend
      extend_seq_lens: torch.Tensor = None
+     extend_prefix_lens: torch.Tensor = None
      extend_start_loc: torch.Tensor = None
      extend_no_prefix: bool = None

@@ -73,8 +70,8 @@ class InputMetadata:

      # For multimodal
      pixel_values: List[torch.Tensor] = None
-     image_sizes: List[List[int]] = None
-     image_offsets: List[int] = None
+     image_sizes: List[List[List[int]]] = None
+     image_offsets: List[List[int]] = None

      # Trition attention backend
      triton_max_seq_len: int = 0
@@ -91,20 +88,8 @@ class InputMetadata:
      def init_multimuldal_info(self, batch: ScheduleBatch):
          reqs = batch.reqs
          self.pixel_values = [r.pixel_values for r in reqs]
-         self.image_sizes = [r.image_size for r in reqs]
-         self.image_offsets = []
-         for r in reqs:
-             if isinstance(r.image_offset, list):
-                 self.image_offsets.append(
-                     [
-                         (image_offset - len(r.prefix_indices))
-                         for image_offset in r.image_offset
-                     ]
-                 )
-             elif isinstance(r.image_offset, int):
-                 self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-             elif r.image_offset is None:
-                 self.image_offsets.append(0)
+         self.image_sizes = [r.image_sizes for r in reqs]
+         self.image_offsets = [r.image_offsets for r in reqs]

      def compute_positions(self, batch: ScheduleBatch):
          position_ids_offsets = batch.position_ids_offsets
@@ -157,6 +142,7 @@ class InputMetadata:
              for i, r in enumerate(batch.reqs)
          ]
          self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+         self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
          self.extend_start_loc = torch.zeros_like(self.seq_lens)
          self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
          self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -183,7 +169,6 @@ class InputMetadata:
      ):
          ret = cls(
              forward_mode=forward_mode,
-             sampling_info=batch.sampling_info,
              batch_size=batch.batch_size(),
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
@@ -194,8 +179,6 @@ class InputMetadata:
              top_logprobs_nums=batch.top_logprobs_nums,
          )

-         ret.sampling_info.prepare_penalties()
-
          ret.compute_positions(batch)

          ret.compute_extend_infos(batch)
@@ -245,10 +228,10 @@ class InputMetadata:
          prefix_lens_cpu,
          flashinfer_use_ragged,
      ):
-         if self.forward_mode != ForwardMode.DECODE:
-             prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-         else:
+         if self.forward_mode == ForwardMode.DECODE:
              prefix_lens = None
+         else:
+             prefix_lens = self.extend_prefix_lens

          update_flashinfer_indices(
              self.forward_mode,
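
For reference, the extend-related fields touched above fit together as follows: each request's total sequence length is its cached prefix length plus the newly extended tokens, and extend_start_loc is the exclusive prefix sum of the extend lengths. A plain-torch illustration with CPU tensors and made-up example numbers:

import torch

prefix_lens_cpu = [0, 4, 2]   # tokens already cached per request
extend_lens_cpu = [5, 3, 7]   # new tokens being prefilled per request

extend_prefix_lens = torch.tensor(prefix_lens_cpu)
extend_seq_lens = torch.tensor(extend_lens_cpu)
seq_lens = extend_prefix_lens + extend_seq_lens

extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
extend_no_prefix = all(l == 0 for l in prefix_lens_cpu)

print(seq_lens.tolist())          # [5, 7, 9]
print(extend_start_loc.tolist())  # [0, 5, 8]
print(extend_no_prefix)           # False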

sglang/srt/model_executor/model_runner.py CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
  import logging
  import pkgutil
  from functools import lru_cache
- from typing import Optional, Tuple, Type
+ from typing import Optional, Type

  import torch
  import torch.nn as nn
@@ -44,15 +44,13 @@ from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry

  from sglang.global_config import global_config
- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
- from sglang.srt.layers.sampler import SampleOutput
  from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
  from sglang.srt.mem_cache.memory_pool import (
      MHATokenToKVPool,
      MLATokenToKVPool,
      ReqToTokenPool,
  )
- from sglang.srt.model_config import AttentionArch
+ from sglang.srt.model_config import AttentionArch, ModelConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -71,7 +69,7 @@ logger = logging.getLogger(__name__)
  class ModelRunner:
      def __init__(
          self,
-         model_config,
+         model_config: ModelConfig,
          mem_fraction_static: float,
          gpu_id: int,
          tp_rank: int,
@@ -87,7 +85,9 @@ class ModelRunner:
          self.tp_size = tp_size
          self.nccl_port = nccl_port
          self.server_args = server_args
-         self.is_multimodal_model = is_multimodal_model(self.model_config)
+         self.is_multimodal_model = is_multimodal_model(
+             self.model_config.hf_config.architectures
+         )
          global_server_args_dict.update(
              {
                  "disable_flashinfer": server_args.disable_flashinfer,
@@ -97,6 +97,13 @@ class ModelRunner:
              }
          )

+         if self.is_multimodal_model:
+             logger.info(
+                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+             )
+             server_args.chunked_prefill_size = None
+             server_args.mem_fraction_static *= 0.95
+
          min_per_gpu_memory = self.init_torch_distributed()
          self.load_model()
          self.init_memory_pool(
@@ -161,6 +168,8 @@ class ModelRunner:
                  "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
              )
              self.server_args.dtype = "float16"
+             if torch.cuda.get_device_capability()[1] < 5:
+                 raise RuntimeError("SGLang only supports sm75 and above.")

          monkey_patch_vllm_dummy_weight_loader()
          self.device_config = DeviceConfig()
@@ -507,19 +516,15 @@ class ModelRunner:
              raise Exception(
                  f"Capture cuda graph failed: {e}\n"
                  "Possible solutions:\n"
-                 "1. disable torch compile by not using --enable-torch-compile\n"
-                 "2. disable cuda graph by --disable-cuda-graph\n"
-                 "3. set --mem-fraction-static to a smaller value\n"
+                 "1. disable cuda graph by --disable-cuda-graph\n"
+                 "2. set --mem-fraction-static to a smaller value\n"
+                 "3. disable torch compile by not using --enable-torch-compile\n"
                  "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
              )

      @torch.inference_mode()
      def forward_decode(self, batch: ScheduleBatch):
-         if (
-             self.cuda_graph_runner
-             and self.cuda_graph_runner.can_run(len(batch.reqs))
-             and not batch.sampling_info.has_bias()
-         ):
+         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
              return self.cuda_graph_runner.replay(batch)

          input_metadata = InputMetadata.from_schedule_batch(
@@ -568,9 +573,7 @@ class ModelRunner:
              input_metadata.image_offsets,
          )

-     def forward(
-         self, batch: ScheduleBatch, forward_mode: ForwardMode
-     ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
          if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
              return self.forward_extend_multi_modal(batch)
          elif forward_mode == ForwardMode.DECODE:
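
The capability gate added above falls back to float16 below sm80 and, within that branch, rejects devices whose minor capability is below 5, so sm75 passes while sm70 does not. Below is a small sketch of that logic as a pure function over (major, minor) so it can be exercised without a GPU; resolve_dtype is an illustrative name, not part of sglang.

def resolve_dtype(major: int, minor: int, requested: str = "bfloat16") -> str:
    # Mirror of the gate above: below sm80, fall back to float16; within that
    # branch, a minor capability below 5 (e.g. sm70) is rejected outright.
    if major < 8:
        requested = "float16"
        if minor < 5:
            raise RuntimeError("SGLang only supports sm75 and above.")
    return requested


print(resolve_dtype(8, 0))  # bfloat16 stays available on sm80 and newer
print(resolve_dtype(7, 5))  # falls back to float16 on sm75
# resolve_dtype(7, 0) would raise, since sm70 is below the supported floor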

sglang/srt/models/chatglm.py CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
  # Adapted from
  # https://github.com/THUDM/ChatGLM2-6B
  """Inference-only ChatGLM model compatible with THUDM weights."""
- from typing import Iterable, List, Optional, Tuple
+ from typing import Iterable, Optional, Tuple

  import torch
  from torch import nn
@@ -31,18 +31,20 @@ from vllm.model_executor.layers.linear import (
  )
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.sequence import SamplerOutput
  from vllm.transformers_utils.configs import ChatGLMConfig

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  LoraConfig = None
@@ -381,11 +383,17 @@ class ChatGLMForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output
+
+     def sample(
+         self,
+         logits: torch.Tensor,
+         sampling_metadata: SamplingMetadata,
+     ) -> Optional[SamplerOutput]:
+         next_tokens = self.sampler(logits, sampling_metadata)
+         return next_tokens

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          params_dict = dict(self.named_parameters(remove_duplicate=False))
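
The same refactor repeats across the model files that follow: the internal sglang Sampler member is dropped, forward returns the LogitsProcessor output directly, and chatglm.py above additionally restores a vLLM-style sample hook. A stand-in skeleton of that shape, not sglang's actual classes:

from typing import Any


class CausalLMSkeleton:
    """Stand-in for the *ForCausalLM classes below; illustrative only."""

    def __init__(self, logits_processor: Any, sampler: Any = None):
        self.logits_processor = logits_processor
        self.sampler = sampler  # only kept where a sample() hook exists

    def forward(self, input_ids, hidden_states, lm_head_weight, input_metadata):
        # No in-model sampling any more: just return the logits-processor output.
        return self.logits_processor(
            input_ids, hidden_states, lm_head_weight, input_metadata
        )

    def sample(self, logits, sampling_metadata):
        # Optional vLLM-style hook, as re-added in chatglm.py above.
        return self.sampler(logits, sampling_metadata)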

sglang/srt/models/commandr.py CHANGED
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
          self.config = config
          self.quant_config = quant_config
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()
          self.model = CohereModel(config, quant_config)

      @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
              positions,
              input_metadata,
          )
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [

sglang/srt/models/dbrx.py CHANGED
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
              padding_size=DEFAULT_VOCAB_PADDING_SIZE,
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          expert_params_mapping = [

sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [

sglang/srt/models/deepseek_v2.py CHANGED
@@ -45,7 +45,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -633,7 +632,6 @@ class DeepseekV2ForCausalLM(nn.Module):
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      def forward(
          self,
@@ -642,11 +640,9 @@ class DeepseekV2ForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [

sglang/srt/models/gemma.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import GeluAndMul
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -34,10 +33,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -61,7 +60,7 @@ class GemmaMLP(nn.Module):
              bias=False,
              quant_config=quant_config,
          )
-         self.act_fn = GeluAndMul()
+         self.act_fn = GeluAndMul("none")

      def forward(self, x):
          gate_up, _ = self.gate_up_proj(x)
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = GemmaModel(config, quant_config=quant_config)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return (sample_output, logits_output)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [

sglang/srt/models/gemma2.py CHANGED
@@ -22,11 +22,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.config import CacheConfig, LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
-
- # FIXME: temporary solution, remove after next vllm release
- from vllm.model_executor.custom_op import CustomOp
-
- # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -39,9 +34,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import GeluAndMul
+ from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -51,52 +46,6 @@ def get_attention_sliding_window_size(config):
      return config.sliding_window - 1


- class GemmaRMSNorm(CustomOp):
-     """RMS normalization for Gemma.
-
-     Two differences from the above RMSNorm:
-         1. x * (1 + w) instead of x * w.
-         2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-     """
-
-     def __init__(
-         self,
-         hidden_size: int,
-         eps: float = 1e-6,
-     ) -> None:
-         super().__init__()
-         self.weight = nn.Parameter(torch.zeros(hidden_size))
-         self.variance_epsilon = eps
-
-     def forward_native(
-         self,
-         x: torch.Tensor,
-         residual: Optional[torch.Tensor] = None,
-     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-         """PyTorch-native implementation equivalent to forward()."""
-         orig_dtype = x.dtype
-         if residual is not None:
-             x = x + residual
-             residual = x
-
-         x = x.float()
-         variance = x.pow(2).mean(dim=-1, keepdim=True)
-         x = x * torch.rsqrt(variance + self.variance_epsilon)
-         # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-         # See https://github.com/huggingface/transformers/pull/29402
-         x = x * (1.0 + self.weight.float())
-         x = x.to(orig_dtype)
-         return x if residual is None else (x, residual)
-
-     def forward_cuda(
-         self,
-         x: torch.Tensor,
-         residual: Optional[torch.Tensor] = None,
-     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-         # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-         return self.forward_native(x, residual)
-
-
  # FIXME: temporary solution, remove after next vllm release
  from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

@@ -397,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = Gemma2Model(config, cache_config, quant_config)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -408,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def get_attention_sliding_window_size(self):
          return get_attention_sliding_window_size(self.config)
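
gemma2.py now imports GemmaRMSNorm from sglang.srt.layers.layernorm instead of defining it locally. The math of the removed forward_native above is: normalize in float32, scale by (1 + w), then cast back to the input dtype. A free-standing sketch of that computation (gemma_rms_norm is an illustrative function, not the module's API):

import torch


def gemma_rms_norm(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Normalize in float32, scale by (1 + w), and cast back afterwards: the two
    # points where Gemma's variant differs from a plain RMSNorm.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + weight.float())
    return x.to(orig_dtype)


x = torch.randn(2, 8, dtype=torch.float16)
w = torch.zeros(8)  # Gemma initializes the norm weight to zeros
y = gemma_rms_norm(x, w)
print(y.dtype, y.shape)  # torch.float16 torch.Size([2, 8])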