sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
@@ -16,9 +16,9 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -27,12 +27,17 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
-        self.free_slots = list(range(size))
+        self.max_context_len = max_context_len
+        self.device = device
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )
+        self.free_slots = list(range(size))
+
+    def available_size(self):
+        return len(self.free_slots)
 
     def alloc(self, need_size: int) -> List[int]:
         if need_size > len(self.free_slots):
@@ -53,86 +58,55 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
+        device: str,
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
 
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
 
     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
-
-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
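The hunk above replaces the old mem_state mask and prefetch buffer in BaseTokenToKVPool with a plain free list held as a numpy array: alloc slices indices off the front, free concatenates them back, and slot 0 is permanently reserved for dummy writes from padded tokens. Below is a minimal, self-contained sketch of that allocation flow; FreeSlotPool is a hypothetical stand-in that mirrors alloc/free/clear with toy sizes, not the sglang class itself.

import numpy as np
import torch


class FreeSlotPool:
    """Hypothetical mirror of the free-list bookkeeping shown in the hunk above."""

    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        self.device = device
        self.clear()

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None  # out of memory; the caller must retract or evict first
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

    def free(self, free_index: torch.Tensor):
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

    def clear(self):
        # Slot 0 is never handed out; usable indices run from 1 to size inclusive.
        self.free_slots = np.arange(1, self.size + 1)


pool = FreeSlotPool(size=8)
locs = pool.alloc(3)          # tensor([1, 2, 3], dtype=torch.int32)
print(pool.available_size())  # 5
pool.free(locs)
print(pool.available_size())  # 8

Keeping a contiguous free list makes allocation a slice instead of the torch.nonzero scan over the boolean mem_state mask that the removed code ran whenever its prefetch buffer was exhausted.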
@@ -152,19 +126,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +190,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
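Both KV pool subclasses now receive the target device explicitly and allocate one extra row per layer so that index 0 can absorb the dummy writes from padded tokens. A toy illustration of the resulting per-layer buffer shape (made-up sizes and CPU tensors, not sglang defaults):

import torch

# One (size + 1, head_num, head_dim) tensor per layer; row 0 is the padded slot.
size, head_num, head_dim, layer_num = 16, 2, 4, 3
k_buffer = [
    torch.empty((size + 1, head_num, head_dim), dtype=torch.float16, device="cpu")
    for _ in range(layer_num)
]
print(k_buffer[0].shape)  # torch.Size([17, 2, 4])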
@@ -31,8 +31,7 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
@@ -143,7 +142,6 @@ class CudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
-        self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
 
         # Capture
@@ -189,7 +187,6 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:bs]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
         # Attention backend
@@ -199,9 +196,10 @@
 
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata(
+            forward_batch = ForwardBatch(
                 forward_mode=ForwardMode.DECODE,
                 batch_size=bs,
+                input_ids=input_ids,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -210,9 +208,9 @@
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
-                positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
+                positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
             )
-            return forward(input_ids, input_metadata.positions, input_metadata)
+            return forward(input_ids, forward_batch.positions, forward_batch)
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -235,24 +233,22 @@
             self.graph_memory_pool = graph.pool()
         return graph, out
 
-    def replay(self, batch: ScheduleBatch):
-        assert batch.out_cache_loc is not None
-        raw_bs = len(batch.reqs)
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
 
         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
-            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs] = batch.input_ids
-        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
-        self.seq_lens[:raw_bs] = batch.seq_lens
-        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
-        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+        self.input_ids[:raw_bs] = forward_batch.input_ids
+        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
+        self.seq_lens[:raw_bs] = forward_batch.seq_lens
+        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -275,15 +271,15 @@
         )
 
         # Extract logprobs
-        if batch.return_logprob:
+        if forward_batch.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=batch.top_logprobs_nums,
+                    top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
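In the rewritten replay above, the incoming batch size is rounded up with bisect to the nearest batch size for which a CUDA graph was captured, and the padded tail of the static input buffers keeps its fill values. A small sketch of that lookup; capture_bs is an assumed example list and padded_batch_size a hypothetical helper, not sglang APIs.

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]  # batch sizes for which graphs were captured

def padded_batch_size(raw_bs: int) -> int:
    # Round raw_bs up to the smallest captured batch size that can hold it.
    index = bisect.bisect_left(capture_bs, raw_bs)
    return capture_bs[index]

print(padded_batch_size(3))  # 4  -> replay the graph captured for bs=4
print(padded_batch_size(8))  # 8  -> exact match, no padding needed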
@@ -15,19 +15,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Meta data for a forward pass."""
+"""
+Store information about a forward batch.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""
+
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -37,7 +51,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both PREFILL and EXTEND.
+    # Contains both EXTEND and DECODE.
     MIXED = auto()
 
     def is_prefill(self):
@@ -54,121 +68,106 @@
 
 
 @dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
+class ForwardBatch:
+    """Store all inputs of a forward pass."""
 
+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor
 
+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None
+
     # Position information
     positions: torch.Tensor = None
 
     # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    extend_no_prefix: bool = None
-
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None
 
     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
-
-    def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
-
-    def compute_positions(self, batch: ScheduleBatch):
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + batch.position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    image_inputs: Optional[List[ImageInputs]] = None
+
+    # For LoRA
+    lora_paths: Optional[List[str]] = None
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo = None
+
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None
 
     @classmethod
-    def from_schedule_batch(
+    def init_new(
         cls,
-        model_runner: "ModelRunner",
-        batch: ScheduleBatch,
+        batch: ModelWorkerBatch,
+        model_runner: ModelRunner,
     ):
+        device = "cuda"
+
         ret = cls(
             forward_mode=batch.forward_mode,
-            batch_size=batch.batch_size(),
+            batch_size=len(batch.seq_lens),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=batch.lora_paths,
+            sampling_info=batch.sampling_info,
         )
 
-        ret.compute_positions(batch)
-
-        if not batch.forward_mode.is_decode():
-            ret.init_multimuldal_info(batch)
-            ret.compute_extend_infos(batch)
-
-        model_runner.attn_backend.init_forward_metadata(batch, ret)
+        # Init position information
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
+                    ],
+                    axis=0,
+                ),
+                device=device,
+            ).to(torch.int64)
+
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(
+                batch.extend_prefix_lens, device=device
+            )
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
+
+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)
+
+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
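ForwardBatch.init_new now computes positions directly from the ModelWorkerBatch fields instead of calling the removed compute_positions helper. A standalone sketch of the two cases, using made-up lengths and CPU tensors:

import numpy as np
import torch

# Decode: one new token per request, so each position is the current length - 1.
seq_lens = torch.tensor([5, 9, 3])
decode_positions = (seq_lens - 1).to(torch.int64)
print(decode_positions)  # tensor([4, 8, 2])

# Extend: requests with cached prefixes of 2 and 0 tokens, extending by 3 and 4.
extend_prefix_lens = [2, 0]
extend_seq_lens = [3, 4]
extend_positions = torch.tensor(
    np.concatenate(
        [
            np.arange(p, p + e)
            for p, e in zip(extend_prefix_lens, extend_seq_lens)
        ]
    )
).to(torch.int64)
print(extend_positions)  # tensor([2, 3, 4, 0, 1, 2, 3])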