sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -16,9 +16,9 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -27,12 +27,17 @@ logger = logging.getLogger(__name__)
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int):
+    def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
-        self.free_slots = list(range(size))
+        self.max_context_len = max_context_len
+        self.device = device
         self.req_to_token = torch.empty(
-            (size, max_context_len), dtype=torch.int32, device="cuda"
+            (size, max_context_len), dtype=torch.int32, device=device
         )
+        self.free_slots = list(range(size))
+
+    def available_size(self):
+        return len(self.free_slots)
 
     def alloc(self, need_size: int) -> List[int]:
         if need_size > len(self.free_slots):
@@ -53,86 +58,55 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
+        device: str,
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
 
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
 
     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
-
-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
 
-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
@@ -152,19 +126,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +190,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
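Note on the memory-pool change above: 0.3.3 drops the boolean mem_state mask and the prefetch buffer and tracks free KV-cache slots with a plain free list, reserving slot 0 as the dummy write target for padded tokens. Below is a minimal, self-contained sketch of that allocation pattern; FreeListPool is an illustrative name, not the sglang class.

import numpy as np
import torch


class FreeListPool:
    """Toy free-list allocator mirroring the new BaseTokenToKVPool bookkeeping."""

    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        self.device = device
        # Slot 0 is reserved as the dummy target for padded tokens.
        self.free_slots = np.arange(1, size + 1)

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None  # Out of slots; the caller must evict or retry.
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

    def free(self, free_index: torch.Tensor):
        # Freed slots are appended back to the free list.
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))


pool = FreeListPool(size=8)
idx = pool.alloc(3)   # e.g., tensor([1, 2, 3], dtype=torch.int32)
pool.free(idx)
assert pool.available_size() == 8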
sglang/srt/mem_cache/radix_cache.py

@@ -291,15 +291,15 @@ class RadixCache(BasePrefixCache):
 
     def _collect_leaves(self):
         ret_list = []
+        stack = [self.root_node]
 
-        def dfs_(cur_node):
+        while stack:
+            cur_node = stack.pop()
             if len(cur_node.children) == 0:
                 ret_list.append(cur_node)
+            else:
+                stack.extend(cur_node.children.values())
 
-            for x in cur_node.children.values():
-                dfs_(x)
-
-        dfs_(self.root_node)
         return ret_list
 
 
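Note on the radix-cache change above: _collect_leaves now walks the tree with an explicit stack instead of a nested recursive helper, which sidesteps Python's recursion-depth limit on deep trees. A standalone sketch of the same traversal on a toy node type (the Node class here is hypothetical, not sglang's TreeNode):

class Node:
    def __init__(self):
        self.children = {}  # key -> Node


def collect_leaves(root: "Node"):
    leaves = []
    stack = [root]
    while stack:
        node = stack.pop()
        if len(node.children) == 0:
            leaves.append(node)
        else:
            stack.extend(node.children.values())
    return leaves


# Root with leaves "a" and "c"; "b" is an internal node.
root = Node()
root.children["a"] = Node()
root.children["b"] = Node()
root.children["b"].children["c"] = Node()
assert len(collect_leaves(root)) == 2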
sglang/srt/model_executor/cuda_graph_runner.py

@@ -25,13 +25,13 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.managers.schedule_batch import ScheduleBatch
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 if TYPE_CHECKING:
@@ -41,14 +41,15 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
-            # NOTE: FusedMoE torch native implementaiton is not efficient
-            if "FusedMoE" in sub.__class__.__name__:
-                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
             else:
-                sub._forward_method = sub.forward_native
+                # NOTE: Temporarily workaround MoE
+                if "FusedMoE" in sub.__class__.__name__:
+                    sub._forward_method = fused_moe_forward_native
+                else:
+                    sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
             _to_torch(sub, reverse)
@@ -67,7 +68,9 @@ def patch_model(
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
             tp_group.ca_comm = None
-            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+            yield torch.compile(
+                torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
+            )
         else:
             yield model.forward
     finally:
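Note on the torch.compile change above: the forward function is wrapped with torch.no_grad() before compilation, so the compiled callable is inference-only and carries no autograd state. A rough, standalone sketch of the pattern (the TinyModel module is made up for illustration):

import torch


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel()
# no_grad is applied first, then the wrapped callable is compiled,
# mirroring torch.compile(torch.no_grad()(model.forward), ...) above.
compiled_forward = torch.compile(
    torch.no_grad()(model.forward), mode="max-autotune-no-cudagraphs"
)
out = compiled_forward(torch.randn(4, 16))
assert not out.requires_grad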
@@ -139,7 +142,6 @@ class CudaGraphRunner:
         self.seq_lens = torch.full(
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
-        self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
 
         # Capture
@@ -150,7 +152,7 @@
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
                 "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
+                "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
                 "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
@@ -185,7 +187,6 @@ class CudaGraphRunner:
         input_ids = self.input_ids[:bs]
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
-        position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
 
         # Attention backend
@@ -195,9 +196,10 @@
 
         # Run and capture
         def run_once():
-            input_metadata = InputMetadata(
+            forward_batch = ForwardBatch(
                 forward_mode=ForwardMode.DECODE,
                 batch_size=bs,
+                input_ids=input_ids,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
                 req_to_token_pool=self.model_runner.req_to_token_pool,
@@ -206,9 +208,9 @@
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
-                positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
+                positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
             )
-            return forward(input_ids, input_metadata.positions, input_metadata)
+            return forward(input_ids, forward_batch.positions, forward_batch)
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -231,24 +233,22 @@
         self.graph_memory_pool = graph.pool()
         return graph, out
 
-    def replay(self, batch: ScheduleBatch):
-        assert batch.out_cache_loc is not None
-        raw_bs = len(batch.reqs)
+    def replay(self, forward_batch: ForwardBatch):
+        assert forward_batch.out_cache_loc is not None
+        raw_bs = forward_batch.batch_size
 
         # Pad
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
             self.seq_lens.fill_(self.seq_len_fill_value)
-            self.position_ids_offsets.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs] = batch.input_ids
-        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
-        self.seq_lens[:raw_bs] = batch.seq_lens
-        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
-        self.out_cache_loc[:raw_bs] = batch.out_cache_loc
+        self.input_ids[:raw_bs] = forward_batch.input_ids
+        self.req_pool_indices[:raw_bs] = forward_batch.req_pool_indices
+        self.seq_lens[:raw_bs] = forward_batch.seq_lens
+        self.out_cache_loc[:raw_bs] = forward_batch.out_cache_loc
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
@@ -271,15 +271,15 @@
         )
 
         # Extract logprobs
-        if batch.return_logprob:
+        if forward_batch.return_logprob:
             logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
                 logits_output.next_token_logits, dim=-1
             )
-            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
                 logits_metadata = LogitsMetadata(
                     forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=batch.top_logprobs_nums,
+                    top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                     logits_output.next_token_logprobs, logits_metadata
sglang/srt/model_executor/forward_batch_info.py

@@ -15,19 +15,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Meta data for a forward pass."""
+"""
+Store information about a forward batch.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""
+
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -37,7 +51,7 @@ class ForwardMode(IntEnum):
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
-    # Contains both PREFILL and EXTEND.
+    # Contains both EXTEND and DECODE.
     MIXED = auto()
 
     def is_prefill(self):
@@ -54,123 +68,106 @@
 
 
 @dataclass
-class InputMetadata:
-    """Store all inforamtion of a forward pass."""
+class ForwardBatch:
+    """Store all inputs of a forward pass."""
 
+    # The forward mode
     forward_mode: ForwardMode
+    # The batch size
     batch_size: int
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
     req_pool_indices: torch.Tensor
+    # The sequence length
     seq_lens: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    attn_backend: AttentionBackend
-
-    # Output location of the KV cache
+    # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor
 
+    # For logprob
+    return_logprob: bool = False
+    top_logprobs_nums: Optional[List[int]] = None
+
     # Position information
     positions: torch.Tensor = None
 
     # For extend
-    extend_seq_lens: torch.Tensor = None
-    extend_prefix_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    extend_no_prefix: bool = None
-
-    # For logprob
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-    extend_seq_lens_cpu: List[int] = None
-    extend_logprob_start_lens_cpu: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_prefix_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None
 
     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
-
-    def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
-
-    def compute_positions(self, batch: ScheduleBatch):
-        position_ids_offsets = batch.position_ids_offsets
-
-        if self.forward_mode.is_decode():
-            if True:
-                self.positions = self.seq_lens - 1
-            else:
-                # Deprecated
-                self.positions = (self.seq_lens - 1) + position_ids_offsets
-        else:
-            if True:
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-            else:
-                # Deprecated
-                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-                self.positions = torch.tensor(
-                    np.concatenate(
-                        [
-                            np.arange(
-                                batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                                len(req.fill_ids) + position_ids_offsets_cpu[i],
-                            )
-                            for i, req in enumerate(batch.reqs)
-                        ],
-                        axis=0,
-                    ),
-                    device="cuda",
-                )
-
-        # Positions should be in long type
-        self.positions = self.positions.to(torch.int64)
-
-    def compute_extend_infos(self, batch: ScheduleBatch):
-        self.extend_seq_lens = torch.tensor(batch.extend_lens_cpu, device="cuda")
-        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
-        self.extend_start_loc = torch.zeros_like(self.extend_seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.extend_no_prefix = all(x == 0 for x in batch.prefix_lens_cpu)
-        self.extend_seq_lens_cpu = batch.extend_lens_cpu
-        self.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens_cpu
+    image_inputs: Optional[List[ImageInputs]] = None
+
+    # For LoRA
+    lora_paths: Optional[List[str]] = None
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo = None
+
+    # Attention backend
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    attn_backend: AttentionBackend = None
 
     @classmethod
-    def from_schedule_batch(
+    def init_new(
         cls,
-        model_runner: "ModelRunner",
-        batch: ScheduleBatch,
+        batch: ModelWorkerBatch,
+        model_runner: ModelRunner,
     ):
+        device = "cuda"
+
         ret = cls(
             forward_mode=batch.forward_mode,
-            batch_size=batch.batch_size(),
+            batch_size=len(batch.seq_lens),
+            input_ids=batch.input_ids,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            attn_backend=model_runner.attn_backend,
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            lora_paths=batch.lora_paths,
+            sampling_info=batch.sampling_info,
         )
 
-        ret.compute_positions(batch)
-
-        if not batch.forward_mode.is_decode():
-            ret.init_multimuldal_info(batch)
-            ret.compute_extend_infos(batch)
-
-        model_runner.attn_backend.init_forward_metadata(batch, ret)
+        # Init position information
+        if ret.forward_mode.is_decode():
+            ret.positions = (ret.seq_lens - 1).to(torch.int64)
+        else:
+            ret.positions = torch.tensor(
+                np.concatenate(
+                    [
+                        np.arange(prefix_len, prefix_len + extend_len)
+                        for prefix_len, extend_len in zip(
+                            batch.extend_prefix_lens, batch.extend_seq_lens
+                        )
+                    ],
+                    axis=0,
+                ),
+                device=device,
+            ).to(torch.int64)
+
+            ret.image_inputs = batch.image_inputs
+            ret.extend_seq_lens = torch.tensor(batch.extend_seq_lens, device=device)
+            ret.extend_prefix_lens = torch.tensor(
+                batch.extend_prefix_lens, device=device
+            )
+            ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
+            ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
+            ret.extend_seq_lens_cpu = batch.extend_seq_lens
+            ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
+
+        # Init attention information
+        ret.req_to_token_pool = model_runner.req_to_token_pool
+        ret.token_to_kv_pool = model_runner.token_to_kv_pool
+        ret.attn_backend = model_runner.attn_backend
+        model_runner.attn_backend.init_forward_metadata(ret)
+
+        # Init lora information
+        if model_runner.server_args.lora_paths is not None:
+            model_runner.lora_manager.prepare_lora_batch(ret)
 
         return ret
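Note on the batching refactor as a whole: per the new module docstring, a batch now flows ScheduleBatch -> ModelWorkerBatch -> ForwardBatch, and ForwardBatch.init_new is the point where CPU-side scheduling data becomes GPU tensors. A hedged sketch of how a caller would drive it; run_forward, model_worker_batch, and model_runner are placeholder names, and the commented call is indicative only:

from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def run_forward(model_worker_batch, model_runner):
    # Convert the CPU-side ModelWorkerBatch produced by the scheduler into
    # GPU tensors (positions, extend info, attention metadata, LoRA state).
    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
    # The model runner then executes the model with this batch, e.g.:
    # logits_output = model_runner.forward(forward_batch)
    return forward_batch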