sglang 0.2.10__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +145 -36
  4. sglang/check_env.py +24 -2
  5. sglang/global_config.py +0 -1
  6. sglang/lang/backend/base_backend.py +3 -1
  7. sglang/lang/backend/openai.py +8 -3
  8. sglang/lang/backend/runtime_endpoint.py +46 -29
  9. sglang/lang/choices.py +164 -0
  10. sglang/lang/interpreter.py +6 -13
  11. sglang/lang/ir.py +11 -2
  12. sglang/srt/layers/logits_processor.py +1 -1
  13. sglang/srt/layers/radix_attention.py +2 -5
  14. sglang/srt/managers/schedule_batch.py +95 -324
  15. sglang/srt/managers/tokenizer_manager.py +6 -3
  16. sglang/srt/managers/tp_worker.py +20 -22
  17. sglang/srt/mem_cache/memory_pool.py +9 -14
  18. sglang/srt/model_executor/cuda_graph_runner.py +3 -3
  19. sglang/srt/model_executor/forward_batch_info.py +256 -0
  20. sglang/srt/model_executor/model_runner.py +6 -10
  21. sglang/srt/models/chatglm.py +1 -1
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/dbrx.py +1 -1
  24. sglang/srt/models/deepseek.py +1 -1
  25. sglang/srt/models/deepseek_v2.py +1 -1
  26. sglang/srt/models/gemma.py +1 -1
  27. sglang/srt/models/gemma2.py +1 -1
  28. sglang/srt/models/gpt_bigcode.py +1 -1
  29. sglang/srt/models/grok.py +1 -1
  30. sglang/srt/models/internlm2.py +1 -1
  31. sglang/srt/models/llama2.py +1 -1
  32. sglang/srt/models/llama_classification.py +1 -1
  33. sglang/srt/models/llava.py +1 -2
  34. sglang/srt/models/llavavid.py +1 -2
  35. sglang/srt/models/minicpm.py +1 -1
  36. sglang/srt/models/mixtral.py +1 -1
  37. sglang/srt/models/mixtral_quant.py +1 -1
  38. sglang/srt/models/qwen.py +1 -1
  39. sglang/srt/models/qwen2.py +1 -1
  40. sglang/srt/models/qwen2_moe.py +1 -1
  41. sglang/srt/models/stablelm.py +1 -1
  42. sglang/srt/openai_api/adapter.py +34 -12
  43. sglang/srt/openai_api/protocol.py +6 -0
  44. sglang/srt/server.py +24 -6
  45. sglang/srt/server_args.py +4 -0
  46. sglang/test/test_utils.py +1 -1
  47. sglang/version.py +1 -1
  48. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/METADATA +34 -24
  49. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/RECORD +52 -50
  50. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
  51. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
  52. {sglang-0.2.10.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tp_worker.py CHANGED
@@ -39,13 +39,13 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -172,7 +172,7 @@ class ModelTpServer:

         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -200,7 +200,6 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         try:
@@ -290,10 +289,10 @@ class ModelTpServer:
                 "KV cache pool leak detected!"
             )

-        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
@@ -353,7 +352,7 @@ class ModelTpServer:
            )
        self.waiting_queue.append(req)

-    def get_new_prefill_batch(self) -> Optional[Batch]:
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
        # TODO(lsyin): organize this function
        running_bs = (
            len(self.running_batch.reqs) if self.running_batch is not None else 0
@@ -364,12 +363,13 @@ class ModelTpServer:
         # Compute matched prefix length
         for req in self.waiting_queue:
             req.input_ids = req.origin_input_ids + req.output_ids
+            try_match_ids = req.input_ids
+            if req.return_logprob:
+                try_match_ids = req.input_ids[: req.logprob_start_len]
+            # NOTE: the prefix_indices must always be aligned with last_node
             prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid,
-                key=req.input_ids,
+                rid=req.rid, key=try_match_ids
             )
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node
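
The hunk above changes how logprob-returning requests interact with the radix cache: the lookup key is now truncated to logprob_start_len before calling match_prefix, instead of truncating the returned prefix_indices afterwards, which keeps prefix_indices aligned with last_node. A minimal sketch of the new key-selection logic (build_match_key is a hypothetical helper name used only for illustration, not part of the package):

    from typing import List

    def build_match_key(
        input_ids: List[int], return_logprob: bool, logprob_start_len: int
    ) -> List[int]:
        # Mirror of the new tp_worker logic: a request that needs prompt
        # logprobs only reuses cached KV up to logprob_start_len, so the
        # radix-cache lookup key is shortened *before* matching.
        if return_logprob:
            return input_ids[:logprob_start_len]
        return input_ids

    # Example: match only the first 3 prompt tokens when logprobs are requested.
    assert build_match_key([1, 2, 3, 4, 5], True, 3) == [1, 2, 3]
    assert build_match_key([1, 2, 3, 4, 5], False, 3) == [1, 2, 3, 4, 5]
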
@@ -525,7 +525,7 @@ class ModelTpServer:
            )

        # Return the new batch
-        new_batch = Batch.init_new(
+        new_batch = ScheduleBatch.init_new(
            can_run_list,
            self.req_to_token_pool,
            self.token_to_kv_pool,
@@ -534,7 +534,7 @@ class ModelTpServer:
        self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
        return new_batch

-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
        # Build batch tensors
        batch.prepare_for_extend(
            self.model_config.vocab_size, self.int_token_logit_bias
@@ -623,14 +623,13 @@ class ModelTpServer:
                    )
                req.output_top_logprobs.append(output.output_top_logprobs[i])

-    def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+    def cache_filled_batch(self, batch: ScheduleBatch):
        for i, req in enumerate(batch.reqs):
            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
                rid=req.rid,
                token_ids=tuple(req.input_ids),
                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
+                req_pool_idx=req.req_pool_idx,
                del_in_memory_pool=False,
                old_last_node=req.last_node,
            )
@@ -638,9 +637,9 @@ class ModelTpServer:

            if req is self.current_inflight_req:
                # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+                self.req_to_token_pool.free(req.req_pool_idx)

-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
        # Check if decode out of memory
        if not batch.check_decode_mem():
            old_ratio = self.new_token_ratio
@@ -699,7 +698,7 @@ class ModelTpServer:

        self.handle_finished_requests(batch)

-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
        output_rids = []
        output_vids = []
        decoded_texts = []
@@ -781,14 +780,13 @@ class ModelTpServer:
        # Remove finished reqs
        if finished_indices:
            # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
            for i in finished_indices:
                req = batch.reqs[i]
                self.tree_cache.cache_req(
                    rid=req.rid,
                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+                    req_pool_idx=req.req_pool_idx,
                )

                self.tree_cache.dec_lock_ref(req.last_node)
@@ -799,7 +797,7 @@ class ModelTpServer:
        else:
            batch.reqs = []

-    def filter_out_inflight(self, batch: Batch):
+    def filter_out_inflight(self, batch: ScheduleBatch):
        # TODO(lsyin): reduce the overhead, make a special version for this
        if self.current_inflight_req is None:
            return

sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
+from typing import List

 import torch

@@ -27,34 +28,28 @@ class ReqToTokenPool:

     def __init__(self, size: int, max_context_len: int):
         self.size = size
-        self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
+        self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
         )
-        self.can_use_mem_size = size

-    def alloc(self, need_size: int):
-        if need_size > self.can_use_mem_size:
+    def alloc(self, need_size: int) -> List[int]:
+        if need_size > len(self.free_slots):
             return None

-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
-        )
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= need_size
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

         return select_index

     def free(self, free_index):
-        self.mem_state[free_index] = True
         if isinstance(free_index, (int,)):
-            self.can_use_mem_size += 1
+            self.free_slots.append(free_index)
         else:
-            self.can_use_mem_size += free_index.shape[0]
+            self.free_slots.extend(free_index)

     def clear(self):
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = len(self.mem_state)
+        self.free_slots = list(range(self.size))


 class BaseTokenToKVPool:
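
For context on the ReqToTokenPool change above: the boolean mem_state tensor and can_use_mem_size counter are replaced by a plain Python free list, which is what the leak check in tp_worker.py now inspects via len(free_slots). A minimal CPU-only sketch of the same free-list pattern (FreeSlotPool is an illustrative stand-in, not the packaged class, which additionally owns the CUDA req_to_token tensor):

    from typing import List, Optional, Union

    class FreeSlotPool:
        def __init__(self, size: int):
            self.size = size
            self.free_slots: List[int] = list(range(size))

        def alloc(self, need_size: int) -> Optional[List[int]]:
            # Take need_size slot indices off the front of the free list,
            # or signal failure if not enough slots remain.
            if need_size > len(self.free_slots):
                return None
            select_index = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return select_index

        def free(self, free_index: Union[int, List[int]]):
            if isinstance(free_index, int):
                self.free_slots.append(free_index)
            else:
                self.free_slots.extend(free_index)

        def clear(self):
            self.free_slots = list(range(self.size))

    pool = FreeSlotPool(4)
    slots = pool.alloc(3)                      # [0, 1, 2]
    pool.free(slots)
    assert len(pool.free_slots) == pool.size   # the invariant the leak check relies on
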

sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -29,8 +29,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
 )
-from sglang.srt.managers.schedule_batch import (
-    Batch,
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     init_flashinfer_args,
@@ -202,7 +202,7 @@ class CudaGraphRunner:
        self.graph_memory_pool = graph.pool()
        return graph, None, out, flashinfer_decode_wrapper

-    def replay(self, batch: Batch):
+    def replay(self, batch: ScheduleBatch):
        assert batch.out_cache_loc is not None
        raw_bs = len(batch.reqs)


sglang/srt/model_executor/forward_batch_info.py ADDED
@@ -0,0 +1,256 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""ModelRunner runs the forward passes of the models."""
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import List
+
+import numpy as np
+import torch
+
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+
+class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+    PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    EXTEND = auto()
+    # Decode one token.
+    DECODE = auto()
+
+
+@dataclass
+class InputMetadata:
+    """Store all information of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    total_num_tokens: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    positions: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: BaseTokenToKVPool
+
+    # For extend
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+    extend_no_prefix: bool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # Triton attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    flashinfer_use_ragged: bool = False
+
+    @classmethod
+    def create(
+        cls,
+        model_runner,
+        forward_mode,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        position_ids_offsets,
+        out_cache_loc,
+        top_logprobs_nums=None,
+        return_logprob=False,
+        skip_flashinfer_init=False,
+    ):
+        flashinfer_use_ragged = False
+        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                flashinfer_use_ragged = True
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+                flashinfer_use_ragged,
+            )
+
+        batch_size = len(req_pool_indices)
+
+        if forward_mode == ForwardMode.DECODE:
+            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+            extend_seq_lens = extend_start_loc = extend_no_prefix = None
+            if not model_runner.server_args.disable_flashinfer:
+                # This variable is not needed in this case,
+                # we do not compute it to make it compatible with cuda graph.
+                total_num_tokens = None
+            else:
+                total_num_tokens = int(torch.sum(seq_lens))
+        else:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            prefix_lens_cpu = prefix_lens.cpu().numpy()
+            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+            positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(
+                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                        )
+                        for i in range(batch_size)
+                    ],
+                    axis=0,
+                ),
+                device="cuda",
+            )
+            extend_seq_lens = seq_lens - prefix_lens
+            extend_start_loc = torch.zeros_like(seq_lens)
+            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
+            extend_no_prefix = torch.all(prefix_lens == 0)
+            total_num_tokens = int(torch.sum(seq_lens))
+
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch_size,
+            total_num_tokens=total_num_tokens,
+            req_pool_indices=req_pool_indices,
+            seq_lens=seq_lens,
+            positions=positions,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=out_cache_loc,
+            extend_seq_lens=extend_seq_lens,
+            extend_start_loc=extend_start_loc,
+            extend_no_prefix=extend_no_prefix,
+            return_logprob=return_logprob,
+            top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        if model_runner.server_args.disable_flashinfer:
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+
+        return ret
+
+
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+    flashinfer_use_ragged=False,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
+
+    if flashinfer_use_ragged:
+        paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        if flashinfer_use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+
+
+def init_triton_args(forward_mode, seq_lens, prefix_lens):
+    """Init auxiliary variables for triton attention backend."""
+    batch_size = len(seq_lens)
+    max_seq_len = int(torch.max(seq_lens))
+    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+
+    if forward_mode == ForwardMode.DECODE:
+        max_extend_len = None
+    else:
+        extend_seq_lens = seq_lens - prefix_lens
+        max_extend_len = int(torch.max(extend_seq_lens))
+
+    return max_seq_len, max_extend_len, start_loc, prefix_lens
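
To make the EXTEND bookkeeping in InputMetadata.create concrete, here is a small CPU-only walk-through of the positions, extend_seq_lens, and extend_start_loc computation with made-up lengths (device="cuda" is dropped so the snippet runs anywhere; it is illustrative, not part of the package):

    import numpy as np
    import torch

    # Two requests: the first has 2 prompt tokens already cached (prefix),
    # the second has no cached prefix.
    seq_lens = torch.tensor([5, 3])
    prefix_lens = torch.tensor([2, 0])
    position_ids_offsets = torch.tensor([0, 0])

    seq_lens_cpu = seq_lens.numpy()
    prefix_lens_cpu = prefix_lens.numpy()
    offsets_cpu = position_ids_offsets.numpy()

    # Positions of the tokens actually run in this extend pass.
    positions = torch.tensor(
        np.concatenate(
            [
                np.arange(prefix_lens_cpu[i] + offsets_cpu[i], seq_lens_cpu[i] + offsets_cpu[i])
                for i in range(len(seq_lens_cpu))
            ],
            axis=0,
        )
    )
    extend_seq_lens = seq_lens - prefix_lens
    extend_start_loc = torch.zeros_like(seq_lens)
    extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)

    print(positions)         # tensor([2, 3, 4, 0, 1, 2])
    print(extend_seq_lens)   # tensor([3, 3])
    print(extend_start_loc)  # tensor([0, 3])
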

sglang/srt/model_executor/model_runner.py CHANGED
@@ -41,18 +41,14 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
 from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -350,7 +346,7 @@ class ModelRunner:
        )

    @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: ScheduleBatch):
        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

@@ -370,7 +366,7 @@ class ModelRunner:
        )

    @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
@@ -387,7 +383,7 @@ class ModelRunner:
        )

    @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
        input_metadata = InputMetadata.create(
            self,
            forward_mode=ForwardMode.EXTEND,
@@ -408,7 +404,7 @@ class ModelRunner:
            batch.image_offsets,
        )

-    def forward(self, batch: Batch, forward_mode: ForwardMode):
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:

sglang/srt/models/chatglm.py CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None


sglang/srt/models/commandr.py CHANGED
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 @torch.compile

sglang/srt/models/dbrx.py CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class DbrxRouter(nn.Module):

sglang/srt/models/deepseek.py CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class DeepseekMLP(nn.Module):

sglang/srt/models/deepseek_v2.py CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class DeepseekV2MLP(nn.Module):

sglang/srt/models/gemma.py CHANGED
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class GemmaMLP(nn.Module):

sglang/srt/models/gemma2.py CHANGED
@@ -42,7 +42,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class GemmaRMSNorm(CustomOp):

sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.schedule_batch import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class GPTBigCodeAttention(nn.Module):

sglang/srt/models/grok.py CHANGED
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 use_fused = True


sglang/srt/models/internlm2.py CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class InternLM2MLP(nn.Module):

sglang/srt/models/llama2.py CHANGED
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class LlamaMLP(nn.Module):

sglang/srt/models/llama_classification.py CHANGED
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel



sglang/srt/models/llava.py CHANGED
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM