sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -15,13 +15,13 @@ limitations under the License.

  """ModelRunner runs the forward passes of the models."""

+ import gc
  import importlib
  import importlib.resources
  import logging
  import pkgutil
- import warnings
  from functools import lru_cache
- from typing import Optional, Type
+ from typing import Optional, Tuple, Type

  import torch
  import torch.nn as nn
@@ -37,10 +37,15 @@ from vllm.distributed import (
  get_tp_group,
  init_distributed_environment,
  initialize_model_parallel,
+ set_custom_all_reduce,
  )
+ from vllm.distributed.parallel_state import in_the_same_node_as
+ from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry

  from sglang.global_config import global_config
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.sampler import SampleOutput
  from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
  from sglang.srt.mem_cache.memory_pool import (
  MHATokenToKVPool,
@@ -53,7 +58,7 @@ from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
  get_available_gpu_memory,
  is_generation_model,
- is_llama3_405b_fp8,
+ is_llama3_405b_fp8_head_16,
  is_multimodal_model,
  monkey_patch_vllm_dummy_weight_loader,
  monkey_patch_vllm_p2p_access_check,
@@ -87,22 +92,35 @@ class ModelRunner:
  {
  "disable_flashinfer": server_args.disable_flashinfer,
  "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
- "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+ "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
  "enable_mla": server_args.enable_mla,
  }
  )

+ min_per_gpu_memory = self.init_torch_distributed()
+ self.load_model()
+ self.init_memory_pool(
+ min_per_gpu_memory,
+ server_args.max_num_reqs,
+ server_args.max_total_tokens,
+ )
+ self.init_cublas()
+ self.init_flashinfer()
+ self.init_cuda_graphs()
+
+ def init_torch_distributed(self):
  # Init torch distributed
  torch.cuda.set_device(self.gpu_id)
- logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
+ logger.info("Init nccl begin.")

- if not server_args.enable_p2p_check:
+ if not self.server_args.enable_p2p_check:
  monkey_patch_vllm_p2p_access_check(self.gpu_id)

- if server_args.nccl_init_addr:
- nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+ if self.server_args.nccl_init_addr:
+ nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
  else:
  nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+ set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
  init_distributed_environment(
  backend="nccl",
  world_size=self.tp_size,
@@ -111,43 +129,43 @@ class ModelRunner:
  distributed_init_method=nccl_init_method,
  )
  initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
- self.tp_group = get_tp_group()
- total_gpu_memory = get_available_gpu_memory(
+ min_per_gpu_memory = get_available_gpu_memory(
  self.gpu_id, distributed=self.tp_size > 1
  )
+ self.tp_group = get_tp_group()

+ # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+ # so we disable padding in cuda graph.
+ if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+ self.server_args.disable_cuda_graph_padding = True
+ logger.info(
+ "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+ )
+
+ # Check memory for tensor parallelism
  if self.tp_size > 1:
- total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
- if total_local_gpu_memory < total_gpu_memory * 0.9:
+ local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+ if min_per_gpu_memory < local_gpu_memory * 0.9:
  raise ValueError(
  "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
  )

- # Load the model and create memory pool
- self.load_model()
- self.init_memory_pool(
- total_gpu_memory,
- server_args.max_num_reqs,
- server_args.max_total_tokens,
- )
- self.init_cublas()
- self.init_flashinfer()
-
- if self.is_generation:
- # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
- # Capture cuda graphs
- self.init_cuda_graphs()
+ return min_per_gpu_memory

  def load_model(self):
  logger.info(
- f"[gpu={self.gpu_id}] Load weight begin. "
- f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+ f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
  )
+ if torch.cuda.get_device_capability()[0] < 8:
+ logger.info(
+ "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+ )
+ self.server_args.dtype = "float16"

  monkey_patch_vllm_dummy_weight_loader()
- device_config = DeviceConfig()
- load_config = LoadConfig(load_format=self.server_args.load_format)
- vllm_model_config = VllmModelConfig(
+ self.device_config = DeviceConfig()
+ self.load_config = LoadConfig(load_format=self.server_args.load_format)
+ self.vllm_model_config = VllmModelConfig(
  model=self.server_args.model_path,
  quantization=self.server_args.quantization,
  tokenizer=None,
@@ -158,47 +176,132 @@ class ModelRunner:
  skip_tokenizer_init=True,
  )

- if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8:
- # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+ # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+ # Drop this after Sept, 2024.
+ if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
  self.model_config.hf_config.num_key_value_heads = 8
- vllm_model_config.hf_config.num_key_value_heads = 8
+ self.vllm_model_config.hf_config.num_key_value_heads = 8
  monkey_patch_vllm_qvk_linear_loader()

- self.dtype = vllm_model_config.dtype
+ self.dtype = self.vllm_model_config.dtype
  if self.model_config.model_overide_args is not None:
- vllm_model_config.hf_config.update(self.model_config.model_overide_args)
-
- if (
- self.server_args.efficient_weight_load
- and "llama" in self.server_args.model_path.lower()
- and self.server_args.quantization == "fp8"
- ):
- from sglang.srt.model_loader.model_loader import get_model
- else:
- from vllm.model_executor.model_loader import get_model
+ self.vllm_model_config.hf_config.update(
+ self.model_config.model_overide_args
+ )

  self.model = get_model(
- model_config=vllm_model_config,
- device_config=device_config,
- load_config=load_config,
- lora_config=None,
- multimodal_config=None,
+ model_config=self.vllm_model_config,
+ load_config=self.load_config,
+ device_config=self.device_config,
  parallel_config=None,
  scheduler_config=None,
+ lora_config=None,
  cache_config=None,
  )
+ self.sliding_window_size = (
+ self.model.get_attention_sliding_window_size()
+ if hasattr(self.model, "get_attention_sliding_window_size")
+ else None
+ )
  self.is_generation = is_generation_model(
- self.model_config.hf_config.architectures
+ self.model_config.hf_config.architectures, self.server_args.is_embedding
  )

  logger.info(
- f"[gpu={self.gpu_id}] Load weight end. "
+ f"Load weight end. "
  f"type={type(self.model).__name__}, "
  f"dtype={self.dtype}, "
  f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
  )

- def profile_max_num_token(self, total_gpu_memory):
+ def update_weights(self, model_path: str, load_format: str):
+ """Update weights in-place."""
+ from vllm.model_executor.model_loader.loader import (
+ DefaultModelLoader,
+ device_loading_context,
+ get_model_loader,
+ )
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+ logger.info(
+ f"Update weights begin. "
+ f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+ )
+
+ target_device = torch.device(self.device_config.device)
+
+ try:
+ # TODO: Use a better method to check this
+ vllm_model_config = VllmModelConfig(
+ model=model_path,
+ quantization=self.server_args.quantization,
+ tokenizer=None,
+ tokenizer_mode=None,
+ trust_remote_code=self.server_args.trust_remote_code,
+ dtype=self.server_args.dtype,
+ seed=42,
+ skip_tokenizer_init=True,
+ )
+ except Exception as e:
+ logger.error(f"Failed to load model config: {e}")
+ return False, "Failed to update model weights"
+
+ load_config = LoadConfig(load_format=load_format)
+
+ # Only support vllm DefaultModelLoader for now
+ loader = get_model_loader(load_config)
+ if not isinstance(loader, DefaultModelLoader):
+ logger.error("Failed to get weights iterator: Unsupported loader")
+ return False, "Failed to update model weights"
+
+ def get_weight_iter(config):
+ iter = loader._get_weights_iterator(
+ config.model,
+ config.revision,
+ fall_back_to_pt=getattr(
+ self.model, "fall_back_to_pt_during_load", True
+ ),
+ )
+ return iter
+
+ def model_load_weights(model, iter):
+ model.load_weights(iter)
+ for _, module in self.model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+ if quant_method is not None:
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
+ return model
+
+ with set_default_torch_dtype(vllm_model_config.dtype):
+ try:
+ iter = get_weight_iter(vllm_model_config)
+ except Exception as e:
+ message = f"Failed to get weights iterator: {e}"
+ logger.error(message)
+ return False, message
+ try:
+ model = model_load_weights(self.model, iter)
+ except Exception as e:
+ message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+ logger.error(message)
+ del iter
+ gc.collect()
+ iter = get_weight_iter(self.vllm_model_config)
+ self.model = model_load_weights(self.model, iter)
+ return False, message
+
+ self.model = model
+ self.server_args.model_path = model_path
+ self.server_args.load_format = load_format
+ self.vllm_model_config = vllm_model_config
+ self.load_config = load_config
+ self.model_config.path = model_path
+
+ logger.info("Update weights end.")
+ return True, "Succeeded to update model weights"
+
+ def profile_max_num_token(self, total_gpu_memory: int):
  available_gpu_memory = get_available_gpu_memory(
  self.gpu_id, distributed=self.tp_size > 1
  )
@@ -209,7 +312,7 @@ class ModelRunner:
  cell_size = (
  (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
  * self.model_config.num_hidden_layers
- * torch._utils._element_size(self.dtype)
+ * torch._utils._element_size(self.kv_cache_dtype)
  )
  else:
  cell_size = (
@@ -217,7 +320,7 @@ class ModelRunner:
  * self.model_config.head_dim
  * self.model_config.num_hidden_layers
  * 2
- * torch._utils._element_size(self.dtype)
+ * torch._utils._element_size(self.kv_cache_dtype)
  )
  rest_memory = available_gpu_memory - total_gpu_memory * (
  1 - self.mem_fraction_static
@@ -226,12 +329,30 @@ class ModelRunner:
  return max_num_token

  def init_memory_pool(
- self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+ self,
+ total_gpu_memory: int,
+ max_num_reqs: int = None,
+ max_total_tokens: int = None,
  ):
+ if self.server_args.kv_cache_dtype == "auto":
+ self.kv_cache_dtype = self.dtype
+ elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+ if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+ logger.warning(
+ "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+ )
+ self.kv_cache_dtype = self.dtype
+ else:
+ self.kv_cache_dtype = torch.float8_e5m2
+ else:
+ raise ValueError(
+ f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+ )
+
  self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
  if max_total_tokens is not None:
  if max_total_tokens > self.max_total_num_tokens:
- warnings.warn(
+ logging.warning(
  f"max_total_tokens={max_total_tokens} is larger than the profiled value "
  f"{self.max_total_num_tokens}. "
  f"Use the profiled value instead."
@@ -264,7 +385,7 @@ class ModelRunner:
  ):
  self.token_to_kv_pool = MLATokenToKVPool(
  self.max_total_num_tokens,
- dtype=self.dtype,
+ dtype=self.kv_cache_dtype,
  kv_lora_rank=self.model_config.kv_lora_rank,
  qk_rope_head_dim=self.model_config.qk_rope_head_dim,
  layer_num=self.model_config.num_hidden_layers,
@@ -275,13 +396,13 @@ class ModelRunner:
  else:
  self.token_to_kv_pool = MHATokenToKVPool(
  self.max_total_num_tokens,
- dtype=self.dtype,
+ dtype=self.kv_cache_dtype,
  head_num=self.model_config.get_num_kv_heads(self.tp_size),
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  )
  logger.info(
- f"[gpu={self.gpu_id}] Memory pool end. "
+ f"Memory pool end. "
  f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
  )

@@ -295,7 +416,11 @@ class ModelRunner:
  return c

  def init_flashinfer(self):
+ """Init flashinfer attention kernel wrappers."""
  if self.server_args.disable_flashinfer:
+ assert (
+ self.sliding_window_size is None
+ ), "turn on flashinfer to support window attention"
  self.flashinfer_prefill_wrapper_ragged = None
  self.flashinfer_prefill_wrapper_paged = None
  self.flashinfer_decode_wrapper = None
@@ -309,36 +434,72 @@ class ModelRunner:
  else:
  use_tensor_cores = False

- self.flashinfer_workspace_buffers = torch.empty(
- 2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
- )
- self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
- self.flashinfer_workspace_buffers[0], "NHD"
- )
- self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
- self.flashinfer_workspace_buffers[1], "NHD"
- )
- self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
- self.flashinfer_workspace_buffers[0],
- "NHD",
- use_tensor_cores=use_tensor_cores,
- )
+ if self.sliding_window_size is None:
+ self.flashinfer_workspace_buffer = torch.empty(
+ global_config.flashinfer_workspace_size,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ self.flashinfer_prefill_wrapper_ragged = (
+ BatchPrefillWithRaggedKVCacheWrapper(
+ self.flashinfer_workspace_buffer, "NHD"
+ )
+ )
+ self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+ self.flashinfer_workspace_buffer, "NHD"
+ )
+ self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+ self.flashinfer_workspace_buffer,
+ "NHD",
+ use_tensor_cores=use_tensor_cores,
+ )
+ else:
+ self.flashinfer_workspace_buffer = torch.empty(
+ global_config.flashinfer_workspace_size,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ self.flashinfer_prefill_wrapper_ragged = None
+ self.flashinfer_prefill_wrapper_paged = []
+ self.flashinfer_decode_wrapper = []
+ for i in range(2):
+ self.flashinfer_prefill_wrapper_paged.append(
+ BatchPrefillWithPagedKVCacheWrapper(
+ self.flashinfer_workspace_buffer, "NHD"
+ )
+ )
+ self.flashinfer_decode_wrapper.append(
+ BatchDecodeWithPagedKVCacheWrapper(
+ self.flashinfer_workspace_buffer,
+ "NHD",
+ use_tensor_cores=use_tensor_cores,
+ )
+ )

  def init_cuda_graphs(self):
+ """Capture cuda graphs."""
+ if not self.is_generation:
+ # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+ return
+
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

  if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
  self.cuda_graph_runner = None
  return

- logger.info(
- f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
- )
- batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+ logger.info("Capture cuda graph begin. This can take up to several minutes.")
+
+ if self.server_args.disable_cuda_graph_padding:
+ batch_size_list = list(range(1, 32)) + [64, 128]
+ else:
+ batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
  self.cuda_graph_runner = CudaGraphRunner(
  self,
  max_batch_size_to_capture=max(batch_size_list),
  use_torch_compile=self.server_args.enable_torch_compile,
+ disable_padding=self.server_args.disable_cuda_graph_padding,
  )
  try:
  self.cuda_graph_runner.capture(batch_size_list)
@@ -354,11 +515,17 @@ class ModelRunner:

  @torch.inference_mode()
  def forward_decode(self, batch: ScheduleBatch):
- if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+ if (
+ self.cuda_graph_runner
+ and self.cuda_graph_runner.can_run(len(batch.reqs))
+ and not batch.sampling_info.has_bias()
+ ):
  return self.cuda_graph_runner.replay(batch)

  input_metadata = InputMetadata.from_schedule_batch(
- self, batch, ForwardMode.DECODE
+ self,
+ batch,
+ ForwardMode.DECODE,
  )

  return self.model.forward(
@@ -368,16 +535,29 @@ class ModelRunner:
  @torch.inference_mode()
  def forward_extend(self, batch: ScheduleBatch):
  input_metadata = InputMetadata.from_schedule_batch(
- self, batch, forward_mode=ForwardMode.EXTEND
- )
- return self.model.forward(
- batch.input_ids, input_metadata.positions, input_metadata
+ self,
+ batch,
+ forward_mode=ForwardMode.EXTEND,
  )
+ if self.is_generation:
+ return self.model.forward(
+ batch.input_ids, input_metadata.positions, input_metadata
+ )
+ else:
+ # Only embedding models have get_embedding parameter
+ return self.model.forward(
+ batch.input_ids,
+ input_metadata.positions,
+ input_metadata,
+ get_embedding=True,
+ )

  @torch.inference_mode()
  def forward_extend_multi_modal(self, batch: ScheduleBatch):
  input_metadata = InputMetadata.from_schedule_batch(
- self, batch, forward_mode=ForwardMode.EXTEND
+ self,
+ batch,
+ forward_mode=ForwardMode.EXTEND,
  )
  return self.model.forward(
  batch.input_ids,
@@ -388,7 +568,9 @@ class ModelRunner:
  input_metadata.image_offsets,
  )

- def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
+ def forward(
+ self, batch: ScheduleBatch, forward_mode: ForwardMode
+ ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
  if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
  return self.forward_extend_multi_modal(batch)
  elif forward_mode == ForwardMode.DECODE:
@@ -444,4 +626,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:


  # Monkey patch model loader
- setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+ setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
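
The init_memory_pool hunk above introduces a kv_cache_dtype setting that is resolved to a torch dtype before the KV pools are allocated. Below is a minimal standalone sketch of that selection logic; the helper name resolve_kv_cache_dtype and its flat argument list are illustrative only and not part of the package.

import torch

def resolve_kv_cache_dtype(
    kv_cache_dtype: str,        # "auto" or "fp8_e5m2", mirroring the new server option
    model_dtype: torch.dtype,   # dtype the weights were loaded in (self.dtype above)
    disable_flashinfer: bool,
    enable_mla: bool,
) -> torch.dtype:
    if kv_cache_dtype == "auto":
        # Keep the KV cache in the model's compute dtype.
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        # The Triton and MLA attention paths do not support an FP8 KV cache yet,
        # so those configurations fall back to the model dtype.
        if disable_flashinfer or enable_mla:
            return model_dtype
        return torch.float8_e5m2
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}")
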
sglang/srt/models/chatglm.py CHANGED
@@ -24,8 +24,6 @@ from torch import nn
  from torch.nn import LayerNorm
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
  MergedColumnParallelLinear,
  QKVParallelLinear,
@@ -33,18 +31,18 @@ from vllm.model_executor.layers.linear import (
  )
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.sampling_metadata import SamplingMetadata
- from vllm.sequence import SamplerOutput
  from vllm.transformers_utils.configs import ChatGLMConfig

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

  LoraConfig = None
@@ -383,17 +381,11 @@ class ChatGLMForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
-
- def sample(
- self,
- logits: torch.Tensor,
- sampling_metadata: SamplingMetadata,
- ) -> Optional[SamplerOutput]:
- next_tokens = self.sampler(logits, sampling_metadata)
- return next_tokens
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/commandr.py CHANGED
@@ -50,7 +50,6 @@ from vllm.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.activation import SiluAndMul
  from vllm.model_executor.layers.linear import (
  MergedColumnParallelLinear,
  QKVParallelLinear,
@@ -62,8 +61,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.model_executor.utils import set_weight_attrs

+ from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -326,6 +327,7 @@ class CohereForCausalLM(nn.Module):
  self.config = config
  self.quant_config = quant_config
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()
  self.model = CohereModel(config, quant_config)

  @torch.no_grad()
@@ -340,9 +342,11 @@ class CohereForCausalLM(nn.Module):
  positions,
  input_metadata,
  )
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
sglang/srt/models/dbrx.py CHANGED
@@ -45,6 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.sampler import Sampler
  from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -382,6 +383,7 @@ class DbrxForCausalLM(nn.Module):
  padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  )
  self.logits_processor = LogitsProcessor(config)
+ self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -391,9 +393,11 @@ class DbrxForCausalLM(nn.Module):
  input_metadata: InputMetadata,
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, input_metadata)
- return self.logits_processor(
+ logits_output = self.logits_processor(
  input_ids, hidden_states, self.lm_head.weight, input_metadata
  )
+ sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+ return sample_output, logits_output

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  expert_params_mapping = [
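
The chatglm.py, commandr.py, and dbrx.py hunks above (and, judging by the file list, likely the other model files as well) follow one pattern: the vLLM Sampler/SamplingMetadata-based sample() method is dropped, each model constructs sglang's own Sampler, and forward() now returns a (sample_output, logits_output) pair built from the LogitsProcessor output and input_metadata.sampling_info. A schematic sketch of that new forward tail is shown below, written as a free-standing function purely for illustration; in the package this code sits inside each model's forward().

def forward_tail(logits_processor, sampler, lm_head_weight,
                 input_ids, hidden_states, input_metadata):
    # LogitsProcessor produces a structured logits output rather than a bare tensor.
    logits_output = logits_processor(
        input_ids, hidden_states, lm_head_weight, input_metadata
    )
    # Sampling now happens inside the model file: sglang's Sampler consumes the
    # logits output together with the batch's sampling info.
    sample_output = sampler(logits_output, input_metadata.sampling_info)
    # Both values propagate up through ModelRunner.forward(), whose return type is
    # now Tuple[SampleOutput, LogitsProcessorOutput].
    return sample_output, logits_output
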