sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -15,13 +15,13 @@ limitations under the License.

  """ModelRunner runs the forward passes of the models."""

+ import gc
  import importlib
  import importlib.resources
  import logging
  import pkgutil
- import warnings
  from functools import lru_cache
- from typing import Optional, Type
+ from typing import Optional, Tuple, Type

  import torch
  import torch.nn as nn
@@ -37,11 +37,15 @@ from vllm.distributed import (
  get_tp_group,
  init_distributed_environment,
  initialize_model_parallel,
+ set_custom_all_reduce,
  )
+ from vllm.distributed.parallel_state import in_the_same_node_as
  from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry

  from sglang.global_config import global_config
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.sampler import SampleOutput
  from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
  from sglang.srt.mem_cache.memory_pool import (
  MHATokenToKVPool,
@@ -88,22 +92,35 @@ class ModelRunner:
  {
  "disable_flashinfer": server_args.disable_flashinfer,
  "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
- "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+ "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
  "enable_mla": server_args.enable_mla,
  }
  )

+ min_per_gpu_memory = self.init_torch_distributed()
+ self.load_model()
+ self.init_memory_pool(
+ min_per_gpu_memory,
+ server_args.max_num_reqs,
+ server_args.max_total_tokens,
+ )
+ self.init_cublas()
+ self.init_flashinfer()
+ self.init_cuda_graphs()
+
+ def init_torch_distributed(self):
  # Init torch distributed
  torch.cuda.set_device(self.gpu_id)
- logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
+ logger.info("Init nccl begin.")

- if not server_args.enable_p2p_check:
+ if not self.server_args.enable_p2p_check:
  monkey_patch_vllm_p2p_access_check(self.gpu_id)

- if server_args.nccl_init_addr:
- nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+ if self.server_args.nccl_init_addr:
+ nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
  else:
  nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+ set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
  init_distributed_environment(
  backend="nccl",
  world_size=self.tp_size,
@@ -112,43 +129,43 @@ class ModelRunner:
  distributed_init_method=nccl_init_method,
  )
  initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
- self.tp_group = get_tp_group()
- total_gpu_memory = get_available_gpu_memory(
+ min_per_gpu_memory = get_available_gpu_memory(
  self.gpu_id, distributed=self.tp_size > 1
  )
+ self.tp_group = get_tp_group()
+
+ # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+ # so we disable padding in cuda graph.
+ if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+ self.server_args.disable_cuda_graph_padding = True
+ logger.info(
+ "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+ )

+ # Check memory for tensor parallelism
  if self.tp_size > 1:
- total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
- if total_local_gpu_memory < total_gpu_memory * 0.9:
+ local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+ if min_per_gpu_memory < local_gpu_memory * 0.9:
  raise ValueError(
  "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
  )

- # Load the model and create memory pool
- self.load_model()
- self.init_memory_pool(
- total_gpu_memory,
- server_args.max_num_reqs,
- server_args.max_total_tokens,
- )
- self.init_cublas()
- self.init_flashinfer()
-
- if self.is_generation:
- # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
- # Capture cuda graphs
- self.init_cuda_graphs()
+ return min_per_gpu_memory

  def load_model(self):
  logger.info(
- f"[gpu={self.gpu_id}] Load weight begin. "
- f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+ f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
  )
+ if torch.cuda.get_device_capability()[0] < 8:
+ logger.info(
+ "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+ )
+ self.server_args.dtype = "float16"

  monkey_patch_vllm_dummy_weight_loader()
- device_config = DeviceConfig()
- load_config = LoadConfig(load_format=self.server_args.load_format)
- vllm_model_config = VllmModelConfig(
+ self.device_config = DeviceConfig()
+ self.load_config = LoadConfig(load_format=self.server_args.load_format)
+ self.vllm_model_config = VllmModelConfig(
  model=self.server_args.model_path,
  quantization=self.server_args.quantization,
  tokenizer=None,
@@ -159,43 +176,132 @@ class ModelRunner:
  skip_tokenizer_init=True,
  )

+ # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+ # Drop this after Sept, 2024.
  if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
- # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
  self.model_config.hf_config.num_key_value_heads = 8
- vllm_model_config.hf_config.num_key_value_heads = 8
+ self.vllm_model_config.hf_config.num_key_value_heads = 8
  monkey_patch_vllm_qvk_linear_loader()

- self.dtype = vllm_model_config.dtype
+ self.dtype = self.vllm_model_config.dtype
  if self.model_config.model_overide_args is not None:
- vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+ self.vllm_model_config.hf_config.update(
+ self.model_config.model_overide_args
+ )

  self.model = get_model(
- model_config=vllm_model_config,
- device_config=device_config,
- load_config=load_config,
- lora_config=None,
- multimodal_config=None,
+ model_config=self.vllm_model_config,
+ load_config=self.load_config,
+ device_config=self.device_config,
  parallel_config=None,
  scheduler_config=None,
+ lora_config=None,
  cache_config=None,
  )
  self.sliding_window_size = (
- self.model.get_window_size()
- if hasattr(self.model, "get_window_size")
+ self.model.get_attention_sliding_window_size()
+ if hasattr(self.model, "get_attention_sliding_window_size")
  else None
  )
  self.is_generation = is_generation_model(
- self.model_config.hf_config.architectures
+ self.model_config.hf_config.architectures, self.server_args.is_embedding
  )

  logger.info(
- f"[gpu={self.gpu_id}] Load weight end. "
+ f"Load weight end. "
  f"type={type(self.model).__name__}, "
  f"dtype={self.dtype}, "
  f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
  )

- def profile_max_num_token(self, total_gpu_memory):
+ def update_weights(self, model_path: str, load_format: str):
+ """Update weights in-place."""
+ from vllm.model_executor.model_loader.loader import (
+ DefaultModelLoader,
+ device_loading_context,
+ get_model_loader,
+ )
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+ logger.info(
+ f"Update weights begin. "
+ f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+ )
+
+ target_device = torch.device(self.device_config.device)
+
+ try:
+ # TODO: Use a better method to check this
+ vllm_model_config = VllmModelConfig(
+ model=model_path,
+ quantization=self.server_args.quantization,
+ tokenizer=None,
+ tokenizer_mode=None,
+ trust_remote_code=self.server_args.trust_remote_code,
+ dtype=self.server_args.dtype,
+ seed=42,
+ skip_tokenizer_init=True,
+ )
+ except Exception as e:
+ logger.error(f"Failed to load model config: {e}")
+ return False, "Failed to update model weights"
+
+ load_config = LoadConfig(load_format=load_format)
+
+ # Only support vllm DefaultModelLoader for now
+ loader = get_model_loader(load_config)
+ if not isinstance(loader, DefaultModelLoader):
+ logger.error("Failed to get weights iterator: Unsupported loader")
+ return False, "Failed to update model weights"
+
+ def get_weight_iter(config):
+ iter = loader._get_weights_iterator(
+ config.model,
+ config.revision,
+ fall_back_to_pt=getattr(
+ self.model, "fall_back_to_pt_during_load", True
+ ),
+ )
+ return iter
+
+ def model_load_weights(model, iter):
+ model.load_weights(iter)
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+ if quant_method is not None:
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
+ return model
+
+ with set_default_torch_dtype(vllm_model_config.dtype):
+ try:
+ iter = get_weight_iter(vllm_model_config)
+ except Exception as e:
+ message = f"Failed to get weights iterator: {e}"
+ logger.error(message)
+ return False, message
+ try:
+ model = model_load_weights(self.model, iter)
+ except Exception as e:
+ message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+ logger.error(message)
+ del iter
+ gc.collect()
+ iter = get_weight_iter(self.vllm_model_config)
+ self.model = model_load_weights(self.model, iter)
+ return False, message
+
+ self.model = model
+ self.server_args.model_path = model_path
+ self.server_args.load_format = load_format
+ self.vllm_model_config = vllm_model_config
+ self.load_config = load_config
+ self.model_config.path = model_path
+
+ logger.info("Update weights end.")
+ return True, "Succeeded to update model weights"
+
+ def profile_max_num_token(self, total_gpu_memory: int):
  available_gpu_memory = get_available_gpu_memory(
  self.gpu_id, distributed=self.tp_size > 1
  )
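The update_weights method added above returns a (success, message) pair instead of raising, and restores the original weights itself when loading fails. A minimal sketch of how a caller might consume that contract (the reload_weights helper and its arguments are illustrative, not sglang code):

    # Sketch only: consuming ModelRunner.update_weights(), which returns
    # (success: bool, message: str) per the hunk above.
    def reload_weights(model_runner, model_path: str, load_format: str):
        success, message = model_runner.update_weights(model_path, load_format)
        if success:
            # The runner has already swapped in the new weights and updated
            # server_args.model_path / server_args.load_format.
            print(f"Weights reloaded from {model_path}: {message}")
        else:
            # On failure the runner rolls back to the original weights, so the
            # caller only needs to surface the error message.
            print(f"Weight update failed, keeping old weights: {message}")
        return success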
@@ -206,7 +312,7 @@ class ModelRunner:
  cell_size = (
  (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
  * self.model_config.num_hidden_layers
- * torch._utils._element_size(self.dtype)
+ * torch._utils._element_size(self.kv_cache_dtype)
  )
  else:
  cell_size = (
@@ -214,7 +320,7 @@ class ModelRunner:
  * self.model_config.head_dim
  * self.model_config.num_hidden_layers
  * 2
- * torch._utils._element_size(self.dtype)
+ * torch._utils._element_size(self.kv_cache_dtype)
  )
  rest_memory = available_gpu_memory - total_gpu_memory * (
  1 - self.mem_fraction_static
@@ -223,12 +329,30 @@ class ModelRunner:
  return max_num_token

  def init_memory_pool(
- self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+ self,
+ total_gpu_memory: int,
+ max_num_reqs: int = None,
+ max_total_tokens: int = None,
  ):
+ if self.server_args.kv_cache_dtype == "auto":
+ self.kv_cache_dtype = self.dtype
+ elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+ if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+ logger.warning(
+ "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+ )
+ self.kv_cache_dtype = self.dtype
+ else:
+ self.kv_cache_dtype = torch.float8_e5m2
+ else:
+ raise ValueError(
+ f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+ )
+
  self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
  if max_total_tokens is not None:
  if max_total_tokens > self.max_total_num_tokens:
- warnings.warn(
+ logging.warning(
  f"max_total_tokens={max_total_tokens} is larger than the profiled value "
  f"{self.max_total_num_tokens}. "
  f"Use the profiled value instead."
@@ -261,7 +385,7 @@ class ModelRunner:
  ):
  self.token_to_kv_pool = MLATokenToKVPool(
  self.max_total_num_tokens,
- dtype=self.dtype,
+ dtype=self.kv_cache_dtype,
  kv_lora_rank=self.model_config.kv_lora_rank,
  qk_rope_head_dim=self.model_config.qk_rope_head_dim,
  layer_num=self.model_config.num_hidden_layers,
@@ -272,13 +396,13 @@ class ModelRunner:
  else:
  self.token_to_kv_pool = MHATokenToKVPool(
  self.max_total_num_tokens,
- dtype=self.dtype,
+ dtype=self.kv_cache_dtype,
  head_num=self.model_config.get_num_kv_heads(self.tp_size),
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  )
  logger.info(
- f"[gpu={self.gpu_id}] Memory pool end. "
+ f"Memory pool end. "
  f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
  )

@@ -292,6 +416,7 @@ class ModelRunner:
  return c

  def init_flashinfer(self):
+ """Init flashinfer attention kernel wrappers."""
  if self.server_args.disable_flashinfer:
  assert (
  self.sliding_window_size is None
@@ -352,20 +477,29 @@ class ModelRunner:
  )

  def init_cuda_graphs(self):
+ """Capture cuda graphs."""
+ if not self.is_generation:
+ # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+ return
+
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

  if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
  self.cuda_graph_runner = None
  return

- logger.info(
- f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
- )
- batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+ logger.info("Capture cuda graph begin. This can take up to several minutes.")
+
+ if self.server_args.disable_cuda_graph_padding:
+ batch_size_list = list(range(1, 32)) + [64, 128]
+ else:
+ batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
  self.cuda_graph_runner = CudaGraphRunner(
  self,
  max_batch_size_to_capture=max(batch_size_list),
  use_torch_compile=self.server_args.enable_torch_compile,
+ disable_padding=self.server_args.disable_cuda_graph_padding,
  )
  try:
  self.cuda_graph_runner.capture(batch_size_list)
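For reference, the two capture lists above expand as follows; the snippet only evaluates the expressions from the hunk and shows that the padded list now reaches batch size 160, where the previous range(1, 17) topped out at 128:

    # Evaluating the batch_size_list expressions from the hunk above.
    padded = [1, 2, 4] + [i * 8 for i in range(1, 21)]    # 1, 2, 4, 8, 16, ..., 160
    unpadded = list(range(1, 32)) + [64, 128]             # 1..31, then 64 and 128 (padding disabled)

    print(max(padded), max(unpadded))   # 160 128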
@@ -381,7 +515,11 @@ class ModelRunner:

  @torch.inference_mode()
  def forward_decode(self, batch: ScheduleBatch):
- if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+ if (
+ self.cuda_graph_runner
+ and self.cuda_graph_runner.can_run(len(batch.reqs))
+ and not batch.sampling_info.has_bias()
+ ):
  return self.cuda_graph_runner.replay(batch)

  input_metadata = InputMetadata.from_schedule_batch(
@@ -401,9 +539,18 @@ class ModelRunner:
  batch,
  forward_mode=ForwardMode.EXTEND,
  )
- return self.model.forward(
- batch.input_ids, input_metadata.positions, input_metadata
- )
+ if self.is_generation:
+ return self.model.forward(
+ batch.input_ids, input_metadata.positions, input_metadata
+ )
+ else:
+ # Only embedding models have get_embedding parameter
+ return self.model.forward(
+ batch.input_ids,
+ input_metadata.positions,
+ input_metadata,
+ get_embedding=True,
+ )

  @torch.inference_mode()
  def forward_extend_multi_modal(self, batch: ScheduleBatch):
@@ -421,7 +568,9 @@ class ModelRunner:
  input_metadata.image_offsets,
  )

- def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+ def forward(
+ self, batch: ScheduleBatch, forward_mode: ForwardMode
+ ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
  if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
  return self.forward_extend_multi_modal(batch)
  elif forward_mode == ForwardMode.DECODE:
@@ -477,4 +626,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:


  # Monkey patch model loader
- setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+ setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
sglang/srt/models/chatglm.py CHANGED
@@ -31,20 +31,18 @@ from vllm.model_executor.layers.linear import (
  )
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.sampling_metadata import SamplingMetadata
- from vllm.sequence import SamplerOutput
  from vllm.transformers_utils.configs import ChatGLMConfig

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> Optional[SamplerOutput]:
- next_tokens = self.sampler(logits, sampling_metadata)
- return next_tokens
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  params_dict = dict(self.named_parameters(remove_duplicate=False))
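The chatglm change above drops the vllm Sampler/SamplingMetadata plumbing and the per-model sample() method; the model files that follow (commandr, dbrx, deepseek, deepseek_v2, gemma, and the other model entries in the list above) converge on the same end state: each model constructs sglang's Sampler and its forward() returns the (sample_output, logits_output) pair consumed by ModelRunner.forward(). A condensed sketch of that pattern, with a placeholder backbone standing in for the real transformer stack (MyBackbone and MyModelForCausalLM are illustrative names, not sglang classes):

    import torch
    import torch.nn as nn

    from sglang.srt.layers.logits_processor import LogitsProcessor
    from sglang.srt.layers.sampler import Sampler
    from sglang.srt.model_executor.forward_batch_info import InputMetadata


    class MyBackbone(nn.Module):
        """Stand-in for the real transformer stack (CohereModel, GemmaModel, ...)."""

        def __init__(self, config, quant_config=None):
            super().__init__()
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        def forward(self, input_ids, positions, input_metadata):
            return self.embed_tokens(input_ids)


    class MyModelForCausalLM(nn.Module):
        def __init__(self, config, quant_config=None):
            super().__init__()
            self.model = MyBackbone(config, quant_config)
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.logits_processor = LogitsProcessor(config)
            self.sampler = Sampler()  # sampling now happens inside the model

        @torch.no_grad()
        def forward(self, input_ids, positions, input_metadata: InputMetadata):
            hidden_states = self.model(input_ids, positions, input_metadata)
            logits_output = self.logits_processor(
                input_ids, hidden_states, self.lm_head.weight, input_metadata
            )
            # Sample with the batch-level sampling_info and hand both results
            # back to ModelRunner.forward().
            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
            return sample_output, logits_output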
sglang/srt/models/commandr.py CHANGED
@@ -64,6 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
  self.config = config
  self.quant_config = quant_config
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()
  self.model = CohereModel(config, quant_config)

  @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
  positions,
  input_metadata,
  )
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/dbrx.py CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
  padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  )
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  expert_params_mapping = [
sglang/srt/models/deepseek.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -385,6 +386,7 @@ class DeepseekForCausalLM(nn.Module):
  config.vocab_size, config.hidden_size, quant_config=quant_config
  )
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -394,9 +396,11 @@ class DeepseekForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py CHANGED
@@ -45,6 +45,7 @@ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -632,6 +633,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  config.vocab_size, config.hidden_size, quant_config=quant_config
  )
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()

  def forward(
  self,
@@ -640,9 +642,11 @@ class DeepseekV2ForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/gemma.py CHANGED
@@ -37,6 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -287,6 +288,7 @@ class GemmaForCausalLM(nn.Module):
  self.quant_config = quant_config
  self.model = GemmaModel(config, quant_config=quant_config)
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -297,9 +299,11 @@ class GemmaForCausalLM(nn.Module):
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return (sample_output, logits_output)

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [