sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py ADDED
@@ -0,0 +1,319 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """ModelRunner runs the forward passes of the models."""
+ from dataclasses import dataclass
+ from enum import IntEnum, auto
+ from typing import TYPE_CHECKING, List
+
+ import numpy as np
+ import torch
+
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
+ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+ if TYPE_CHECKING:
+     from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+ class ForwardMode(IntEnum):
+     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+     PREFILL = auto()
+     # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+     EXTEND = auto()
+     # Decode one token.
+     DECODE = auto()
+
+
+ @dataclass
+ class InputMetadata:
+     """Store all inforamtion of a forward pass."""
+
+     forward_mode: ForwardMode
+     batch_size: int
+     req_pool_indices: torch.Tensor
+     seq_lens: torch.Tensor
+     req_to_token_pool: ReqToTokenPool
+     token_to_kv_pool: BaseTokenToKVPool
+
+     # Output location of the KV cache
+     out_cache_loc: torch.Tensor
+
+     total_num_tokens: int = None
+
+     # Position information
+     positions: torch.Tensor = None
+
+     # For extend
+     extend_seq_lens: torch.Tensor = None
+     extend_start_loc: torch.Tensor = None
+     extend_no_prefix: bool = None
+
+     # Output options
+     return_logprob: bool = False
+     top_logprobs_nums: List[int] = None
+
+     # For multimodal
+     pixel_values: List[torch.Tensor] = None
+     image_sizes: List[List[int]] = None
+     image_offsets: List[int] = None
+
+     # Trition attention backend
+     triton_max_seq_len: int = 0
+     triton_max_extend_len: int = 0
+     triton_start_loc: torch.Tensor = None
+     triton_prefix_lens: torch.Tensor = None
+
+     # FlashInfer attention backend
+     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+     flashinfer_use_ragged: bool = False
+
+     def init_multimuldal_info(self, batch: ScheduleBatch):
+         reqs = batch.reqs
+         self.pixel_values = [r.pixel_values for r in reqs]
+         self.image_sizes = [r.image_size for r in reqs]
+         self.image_offsets = [
+             (
+                 (r.image_offset - len(r.prefix_indices))
+                 if r.image_offset is not None
+                 else 0
+             )
+             for r in reqs
+         ]
+
+     def compute_positions(self, batch: ScheduleBatch):
+         position_ids_offsets = batch.position_ids_offsets
+
+         if self.forward_mode == ForwardMode.DECODE:
+             if True:
+                 self.positions = self.seq_lens - 1
+             else:
+                 # Deprecated
+                 self.positions = (self.seq_lens - 1) + position_ids_offsets
+         else:
+             if True:
+                 self.positions = torch.tensor(
+                     np.concatenate(
+                         [
+                             np.arange(len(req.prefix_indices), len(req.fill_ids))
+                             for req in batch.reqs
+                         ],
+                         axis=0,
+                     ),
+                     device="cuda",
+                 )
+             else:
+                 # Deprecated
+                 position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                 self.positions = torch.tensor(
+                     np.concatenate(
+                         [
+                             np.arange(
+                                 len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                 len(req.fill_ids) + position_ids_offsets_cpu[i],
+                             )
+                             for i, req in enumerate(batch.reqs)
+                         ],
+                         axis=0,
+                     ),
+                     device="cuda",
+                 )
+
+         # Positions should be in long type
+         self.positions = self.positions.to(torch.int64)
+
+     def compute_extend_infos(self, batch: ScheduleBatch):
+         if self.forward_mode == ForwardMode.DECODE:
+             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+         else:
+             extend_lens_cpu = [
+                 len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+             ]
+             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+             self.extend_start_loc = torch.zeros_like(self.seq_lens)
+             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+             self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+
+     @classmethod
+     def from_schedule_batch(
+         cls,
+         model_runner: "ModelRunner",
+         batch: ScheduleBatch,
+         forward_mode: ForwardMode,
+     ):
+         ret = cls(
+             forward_mode=forward_mode,
+             batch_size=batch.batch_size(),
+             req_pool_indices=batch.req_pool_indices,
+             seq_lens=batch.seq_lens,
+             req_to_token_pool=model_runner.req_to_token_pool,
+             token_to_kv_pool=model_runner.token_to_kv_pool,
+             out_cache_loc=batch.out_cache_loc,
+             return_logprob=batch.return_logprob,
+             top_logprobs_nums=batch.top_logprobs_nums,
+         )
+
+         ret.compute_positions(batch)
+
+         ret.compute_extend_infos(batch)
+
+         if (
+             forward_mode != ForwardMode.DECODE
+             or model_runner.server_args.disable_flashinfer
+         ):
+             ret.total_num_tokens = int(torch.sum(ret.seq_lens))
+
+         if forward_mode != ForwardMode.DECODE:
+             ret.init_multimuldal_info(batch)
+
+         prefix_lens = None
+         if forward_mode != ForwardMode.DECODE:
+             prefix_lens = torch.tensor(
+                 [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+             )
+
+         if model_runner.server_args.disable_flashinfer:
+             ret.init_triton_args(batch, prefix_lens)
+
+         flashinfer_use_ragged = False
+         if not model_runner.server_args.disable_flashinfer:
+             if (
+                 forward_mode != ForwardMode.DECODE
+                 and int(torch.sum(ret.seq_lens)) > 4096
+             ):
+                 flashinfer_use_ragged = True
+             ret.init_flashinfer_handlers(
+                 model_runner, prefix_lens, flashinfer_use_ragged
+             )
+
+         return ret
+
+     def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+         """Init auxiliary variables for triton attention backend."""
+         self.triton_max_seq_len = int(torch.max(self.seq_lens))
+         self.triton_prefix_lens = prefix_lens
+         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+         if self.forward_mode == ForwardMode.DECODE:
+             self.triton_max_extend_len = None
+         else:
+             extend_seq_lens = self.seq_lens - prefix_lens
+             self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+     def init_flashinfer_handlers(
+         self, model_runner, prefix_lens, flashinfer_use_ragged
+     ):
+         update_flashinfer_indices(
+             self.forward_mode,
+             model_runner,
+             self.req_pool_indices,
+             self.seq_lens,
+             prefix_lens,
+             flashinfer_use_ragged=flashinfer_use_ragged,
+         )
+
+         (
+             self.flashinfer_prefill_wrapper_ragged,
+             self.flashinfer_prefill_wrapper_paged,
+             self.flashinfer_decode_wrapper,
+             self.flashinfer_use_ragged,
+         ) = (
+             model_runner.flashinfer_prefill_wrapper_ragged,
+             model_runner.flashinfer_prefill_wrapper_paged,
+             model_runner.flashinfer_decode_wrapper,
+             flashinfer_use_ragged,
+         )
+
+
+ def update_flashinfer_indices(
+     forward_mode,
+     model_runner,
+     req_pool_indices,
+     seq_lens,
+     prefix_lens,
+     flashinfer_decode_wrapper=None,
+     flashinfer_use_ragged=False,
+ ):
+     """Init auxiliary variables for FlashInfer attention backend."""
+     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+     head_dim = model_runner.model_config.head_dim
+     batch_size = len(req_pool_indices)
+
+     if flashinfer_use_ragged:
+         paged_kernel_lens = prefix_lens
+     else:
+         paged_kernel_lens = seq_lens
+
+     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+     kv_indices = torch.cat(
+         [
+             model_runner.req_to_token_pool.req_to_token[
+                 req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+             ]
+             for i in range(batch_size)
+         ],
+         dim=0,
+     ).contiguous()
+     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+     if forward_mode == ForwardMode.DECODE:
+         # CUDA graph uses different flashinfer_decode_wrapper
+         if flashinfer_decode_wrapper is None:
+             flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+         flashinfer_decode_wrapper.end_forward()
+         flashinfer_decode_wrapper.begin_forward(
+             kv_indptr,
+             kv_indices,
+             kv_last_page_len,
+             num_qo_heads,
+             num_kv_heads,
+             head_dim,
+             1,
+         )
+     else:
+         # extend part
+         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+         if flashinfer_use_ragged:
+             model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+             model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                 qo_indptr,
+                 qo_indptr,
+                 num_qo_heads,
+                 num_kv_heads,
+                 head_dim,
+             )
+
+         # cached part
+         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+         model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+             qo_indptr,
+             kv_indptr,
+             kv_indices,
+             kv_last_page_len,
+             num_qo_heads,
+             num_kv_heads,
+             head_dim,
+             1,
+         )
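The hunks below switch ModelRunner over to this new constructor. As a rough, non-authoritative sketch of the new call pattern (assuming `model_runner` is an already-initialized ModelRunner and `batch` a prepared ScheduleBatch; both names come straight from the diffs in this release, nothing here is new API surface):

    # Sketch only: build per-forward metadata from a scheduled batch, then run the model.
    from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

    input_metadata = InputMetadata.from_schedule_batch(
        model_runner, batch, forward_mode=ForwardMode.EXTEND
    )
    logits_output = model_runner.model.forward(
        batch.input_ids, input_metadata.positions, input_metadata
    )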
sglang/srt/model_executor/model_runner.py CHANGED
@@ -41,21 +41,18 @@ from vllm.distributed import (
  from vllm.model_executor.models import ModelRegistry

  from sglang.global_config import global_config
- from sglang.srt.managers.schedule_batch import (
-     Batch,
-     ForwardMode,
-     InputMetadata,
-     global_server_args_dict,
- )
+ from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
  from sglang.srt.mem_cache.memory_pool import (
      MHATokenToKVPool,
      MLATokenToKVPool,
      ReqToTokenPool,
  )
  from sglang.srt.model_config import AttentionArch
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
      get_available_gpu_memory,
+     is_generation_model,
      is_llama3_405b_fp8,
      is_multimodal_model,
      monkey_patch_vllm_dummy_weight_loader,
@@ -134,10 +131,12 @@ class ModelRunner:
              server_args.max_total_tokens,
          )
          self.init_cublas()
-         self.init_flash_infer()
+         self.init_flashinfer()

-         # Capture cuda graphs
-         self.init_cuda_graphs()
+         if self.is_generation:
+             # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+             # Capture cuda graphs
+             self.init_cuda_graphs()

      def load_model(self):
          logger.info(
@@ -188,6 +187,10 @@ class ModelRunner:
              scheduler_config=None,
              cache_config=None,
          )
+         self.is_generation = is_generation_model(
+             self.model_config.hf_config.architectures
+         )
+
          logger.info(
              f"[gpu={self.gpu_id}] Load weight end. "
              f"type={type(self.model).__name__}, "
@@ -291,7 +294,7 @@ class ModelRunner:
          c = a @ b
          return c

-     def init_flash_infer(self):
+     def init_flashinfer(self):
          if self.server_args.disable_flashinfer:
              self.flashinfer_prefill_wrapper_ragged = None
              self.flashinfer_prefill_wrapper_paged = None
@@ -350,65 +353,42 @@
          )

      @torch.inference_mode()
-     def forward_decode(self, batch: Batch):
+     def forward_decode(self, batch: ScheduleBatch):
          if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
              return self.cuda_graph_runner.replay(batch)

-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.DECODE,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
+         input_metadata = InputMetadata.from_schedule_batch(
+             self, batch, ForwardMode.DECODE
          )
+
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
          )

      @torch.inference_mode()
-     def forward_extend(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.EXTEND,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
+     def forward_extend(self, batch: ScheduleBatch):
+         input_metadata = InputMetadata.from_schedule_batch(
+             self, batch, forward_mode=ForwardMode.EXTEND
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
          )

      @torch.inference_mode()
-     def forward_extend_multi_modal(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.EXTEND,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             return_logprob=batch.return_logprob,
-             top_logprobs_nums=batch.top_logprobs_nums,
+     def forward_extend_multi_modal(self, batch: ScheduleBatch):
+         input_metadata = InputMetadata.from_schedule_batch(
+             self, batch, forward_mode=ForwardMode.EXTEND
          )
          return self.model.forward(
              batch.input_ids,
              input_metadata.positions,
              input_metadata,
-             batch.pixel_values,
-             batch.image_sizes,
-             batch.image_offsets,
+             input_metadata.pixel_values,
+             input_metadata.image_sizes,
+             input_metadata.image_offsets,
          )

-     def forward(self, batch: Batch, forward_mode: ForwardMode):
+     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
          if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
              return self.forward_extend_multi_modal(batch)
          elif forward_mode == ForwardMode.DECODE:
@@ -433,8 +413,10 @@ def import_model_classes():
                  entry, list
              ):  # To support multiple model classes in one module
                  for tmp in entry:
+                     assert tmp.__name__ not in model_arch_name_to_cls
                      model_arch_name_to_cls[tmp.__name__] = tmp
              else:
+                 assert entry.__name__ not in model_arch_name_to_cls
                  model_arch_name_to_cls[entry.__name__] = entry

      # compat: some models such as chatglm has incorrect class set in config.json
@@ -444,6 +426,7 @@ def import_model_classes():
          ):
              for remap in module.EntryClassRemapping:
                  if isinstance(remap, tuple) and len(remap) == 2:
+                     assert remap[0] not in model_arch_name_to_cls
                      model_arch_name_to_cls[remap[0]] = remap[1]

      return model_arch_name_to_cls
sglang/srt/models/chatglm.py CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata

  LoraConfig = None

sglang/srt/models/commandr.py CHANGED
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  @torch.compile
sglang/srt/models/dbrx.py CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class DbrxRouter(nn.Module):
sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.schedule_batch import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class DeepseekMLP(nn.Module):
sglang/srt/models/deepseek_v2.py CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class DeepseekV2MLP(nn.Module):
sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class GemmaMLP(nn.Module):
sglang/srt/models/gemma2.py CHANGED
@@ -38,11 +38,10 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
  # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.sampling_metadata import SamplingMetadata

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class GemmaRMSNorm(CustomOp):
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.managers.schedule_batch import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class GPTBigCodeAttention(nn.Module):
sglang/srt/models/grok.py CHANGED
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
  from sglang.srt.layers.fused_moe import fused_moe
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata

  use_fused = True

sglang/srt/models/internlm2.py CHANGED
@@ -23,8 +23,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class InternLM2MLP(nn.Module):
-
      def __init__(
          self,
          hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):


  class InternLM2Attention(nn.Module):
-
      def __init__(
          self,
          hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):


  class InternLMDecoderLayer(nn.Module):
-
      def __init__(
          self,
          config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):


  class InternLM2Model(nn.Module):
-
      def __init__(
          self,
          config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):


  class InternLM2ForCausalLM(nn.Module):
-
      def __init__(
          self,
          config: PretrainedConfig,
sglang/srt/models/llama2.py CHANGED
@@ -24,8 +24,6 @@ from torch import nn
  from transformers import LlamaConfig
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -39,9 +37,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class LlamaMLP(nn.Module):
@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
          positions: torch.Tensor,
          input_metadata: InputMetadata,
          input_embeds: torch.Tensor = None,
-     ) -> torch.Tensor:
+     ) -> LogitProcessorOutput:
          hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
          return self.logits_processor(
              input_ids, hidden_states, self.lm_head.weight, input_metadata
sglang/srt/models/llama_classification.py CHANGED
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitProcessorOutput
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
  from sglang.srt.models.llama2 import LlamaModel
