sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/baichuan.py ADDED
@@ -0,0 +1,416 @@
+ # coding=utf-8
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Inference-only BaiChuan model compatible with HuggingFace weights."""
+ import math
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+     closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
+     base = torch.tensor(
+         2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
+         dtype=torch.float32,
+     )
+     powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+     slopes = torch.pow(base, powers)
+
+     if closest_power_of_2 != total_num_heads:
+         extra_base = torch.tensor(
+             2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
+             dtype=torch.float32,
+         )
+         num_remaining_heads = min(
+             closest_power_of_2, total_num_heads - closest_power_of_2
+         )
+         extra_powers = torch.arange(
+             start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
+         )
+         slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+     return slopes
+
+
+ class BaiChuanMLP(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size, hidden_size, bias=False, quant_config=quant_config
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class BaiChuanAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         position_embedding: str,
+         rope_theta: float = 10000,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         layer_id: int = 0,
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tensor_model_parallel_world_size == 0
+         self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+         self.head_dim = hidden_size // self.total_num_heads
+         self.postion_embedding = position_embedding
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.total_num_kv_heads = self.num_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+         # pylint: disable=invalid-name
+         self.W_pack = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         # Create the alibi slopes and slice them.
+         if self.postion_embedding == "ALIBI":
+             tp_rank = get_tensor_model_parallel_rank()
+             head_start = tp_rank * self.num_heads
+             head_end = (tp_rank + 1) * self.num_heads
+             alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+             scaling = self.head_dim**-0.5
+             self.attn = RadixAttention(
+                 self.num_heads,
+                 self.head_dim,
+                 scaling,
+                 num_kv_heads=self.num_kv_heads,
+                 layer_id=layer_id,
+             )
+         else:
+             self.rotary_emb = get_rope(
+                 self.head_dim,
+                 rotary_dim=self.head_dim,
+                 max_position=self.max_position_embeddings,
+                 base=self.rope_theta,
+             )
+             self.scaling = self.head_dim**-0.5
+             self.attn = RadixAttention(
+                 self.num_heads,
+                 self.head_dim,
+                 self.scaling,
+                 num_kv_heads=self.num_kv_heads,
+                 layer_id=layer_id,
+             )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.W_pack(hidden_states)
+         q, k, v = qkv.chunk(chunks=3, dim=-1)
+         if self.postion_embedding != "ALIBI":
+             q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class BaiChuanDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         position_embedding: str,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         self.self_attn = BaiChuanAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             position_embedding=position_embedding,
+             rope_theta=rope_theta,
+             layer_id=layer_id,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+         )
+         self.mlp = BaiChuanMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class BaiChuanModel(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         position_embedding: str,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 BaiChuanDecoderLayer(
+                     config,
+                     layer_id=i,
+                     position_embedding=position_embedding,
+                     quant_config=quant_config,
+                 )
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.embed_tokens(input_ids)
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class BaiChuanBaseForCausalLM(nn.Module):
+     packed_modules_mapping = {
+         "W_pack": ["W_pack"],
+         "gate_up_proj": [
+             "gate_proj",
+             "up_proj",
+         ],
+     }
+     # LoRA specific attributes
+     supported_lora_modules = [
+         "W_pack",
+         "o_proj",
+         "gate_up_proj",
+         "down_proj",
+     ]
+     embedding_modules = {}
+     embedding_padding_modules = []
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         position_embedding: str,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+
+         self.config = config
+
+         self.quant_config = quant_config
+         self.model = BaiChuanModel(config, position_embedding, quant_config)
+         self.lm_head = ParallelLMHead(
+             config.vocab_size, config.hidden_size, quant_config=quant_config
+         )
+         if self.config.tie_word_embeddings:
+             self.lm_head.weight = self.model.embed_tokens.weight
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             if name == "lm_head.weight":
+                 # Unlike Baichuan, Baichuan2 normalizes the head weights.
+                 # Refer to:
+                 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
+                 # Distinguish between Baichuan and Baichuan2 by checking the
+                 # vocab size. This is suggested by
+                 # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
+                 is_baichuan2 = self.config.vocab_size == 125696
+                 if is_baichuan2:
+                     loaded_weight = torch.nn.functional.normalize(loaded_weight)
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
+     """Baichuan 13B and Baichuan2 7B/13B."""
+
+     def __init__(
+         self,
+         config,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         if config.hidden_size == 4096:  # baichuan2 7b
+             super().__init__(config, "ROPE", cache_config, quant_config)
+         else:  # baichuan 13b, baichuan2 13b
+             super().__init__(config, "ALIBI", cache_config, quant_config)
+
+
+ EntryClass = [BaichuanForCausalLM]
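
Note on the new file: `_get_alibi_slopes` above follows the standard ALiBi slope schedule (a geometric sequence of per-head slopes, extended with interleaved extra slopes when the head count is not a power of two). Below is a minimal, torch-free sketch of the same computation, for illustration only; `alibi_slopes` is not part of the package.

import math

def alibi_slopes(total_num_heads: int) -> list:
    # Same schedule as _get_alibi_slopes in baichuan.py above, without torch.
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base**p for p in range(1, closest_power_of_2 + 1)]
    if closest_power_of_2 != total_num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = min(closest_power_of_2, total_num_heads - closest_power_of_2)
        # Odd powers 1, 3, 5, ... interleave the extra heads between the base slopes.
        slopes += [extra_base**p for p in range(1, 2 * num_remaining, 2)]
    return slopes

print(alibi_slopes(8))   # [0.5, 0.25, ..., 0.00390625] for a power-of-two head count
print(alibi_slopes(12))  # 8 base slopes followed by 4 interleaved extra slopes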
sglang/srt/models/chatglm.py CHANGED
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  LoraConfig = None
@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
          self.transformer = ChatGLMModel(config, cache_config, quant_config)
          self.lm_head = self.transformer.output_layer
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/commandr.py CHANGED
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
          self.config = config
          self.quant_config = quant_config
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()
          self.model = CohereModel(config, quant_config)

      @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
              positions,
              input_metadata,
          )
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/dbrx.py CHANGED
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
              padding_size=DEFAULT_VOCAB_PADDING_SIZE,
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          expert_params_mapping = [
sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      def forward(
          self,
@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/exaone.py CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
          self.transformer = ExaoneModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
          hidden_states = self.transformer(
              input_ids, positions, input_metadata, input_embeds
          )
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = GemmaModel(config, quant_config=quant_config)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return (sample_output, logits_output)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
sglang/srt/models/gemma2.py CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
          self.quant_config = quant_config
          self.model = Gemma2Model(config, cache_config, quant_config)
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
          input_embeds: torch.Tensor = None,
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def get_attention_sliding_window_size(self):
          return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
          if lora_config:
              self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
          self.logits_processor = LogitsProcessor(config)
-         self.sampler = Sampler()

      @torch.no_grad()
      def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          hidden_states = self.transformer(input_ids, positions, input_metadata)
-         logits_output = self.logits_processor(
+         return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
          )
-         sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-         return sample_output, logits_output

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          params_dict = dict(self.named_parameters(remove_duplicate=False))
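
The model diffs above all apply the same refactor: the per-model `Sampler` import and `self.sampler` attribute are removed, and `forward` returns the `LogitsProcessor` output directly instead of a `(sample_output, logits_output)` pair, so token sampling now happens outside the model. Below is a self-contained toy sketch of the resulting call pattern; the `Toy*` names are illustrative and not sglang APIs.

import torch

class ToyLogitsProcessor:
    # Stand-in for a logits processor: project hidden states to vocabulary logits.
    def __call__(self, hidden_states, lm_head_weight):
        return hidden_states @ lm_head_weight.T

class ToyForCausalLM:
    # 0.3.1-style model: forward() returns logits only; no internal Sampler.
    def __init__(self, hidden_size=8, vocab_size=32):
        self.lm_head_weight = torch.randn(vocab_size, hidden_size)
        self.logits_processor = ToyLogitsProcessor()

    def forward(self, hidden_states):
        # In 0.3.0 the model would also sample here and return (sample_output, logits_output).
        return self.logits_processor(hidden_states, self.lm_head_weight)

model = ToyForCausalLM()
logits = model.forward(torch.randn(1, 8))
print(torch.argmax(logits, dim=-1))  # caller-side (greedy) sampling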