sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py CHANGED
@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
 
             # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            all_logits = torch.matmul(states, weight.T)
+            all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
             output_top_logprobs=output_top_logprobs,
         )
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+
 
 def test():
     all_logprobs = torch.tensor(
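The change above means callers now hand the `LogitsProcessor` the `lm_head` module itself rather than its weight tensor, and `_get_logits` picks between a dense matmul and the layer's quantized `linear_method` (the GGUF path). A minimal, illustrative sketch of that dispatch, using a stand-in module instead of the real `VocabParallelEmbedding`:

```python
# Illustrative sketch only: TinyLMHead stands in for VocabParallelEmbedding.
import torch
import torch.nn as nn


class TinyLMHead(nn.Module):
    """Stand-in lm_head exposing a dense .weight of shape [vocab_size, hidden_size]."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab_size, hidden_size))


def get_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # Mirrors LogitsProcessor._get_logits: use the dense weight when the module has one,
    # otherwise fall back to the layer's quantized linear method (e.g. GGUF checkpoints).
    if hasattr(lm_head, "weight"):
        return torch.matmul(hidden_states, lm_head.weight.T)
    return lm_head.linear_method.apply(lm_head, hidden_states, None)


lm_head = TinyLMHead(vocab_size=32, hidden_size=8)
hidden = torch.randn(4, 8)
print(get_logits(hidden, lm_head).shape)  # torch.Size([4, 32])
```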
sglang/srt/layers/quantization/__init__.py CHANGED
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
 # Apply patches when module is imported
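The new patches follow the same pattern as the existing fp8 one: the monkey-patched `get_quant_method` dispatches on the layer type, so sglang's `FusedMoE` layers get a Marlin MoE method while plain linear layers keep the linear method. A minimal sketch of that dispatch shape, with stand-in classes instead of the real vllm/sglang types:

```python
# Stand-in classes; the real code returns vllm's GPTQMarlin*/AWQ* methods.
class LinearBase: ...
class FusedMoE: ...


class MarlinLinearMethod:
    def __init__(self, config):
        self.config = config


class MarlinMoEMethod:
    def __init__(self, config):
        self.config = config


class QuantConfig:
    def get_quant_method(self, layer, prefix: str = ""):
        # Same shape as gptq_get_quant_method / awq_get_quant_method above:
        # choose the method implementation based on the layer class.
        if isinstance(layer, LinearBase):
            return MarlinLinearMethod(self)
        if isinstance(layer, FusedMoE):
            return MarlinMoEMethod(self)
        return None


cfg = QuantConfig()
print(type(cfg.get_quant_method(LinearBase())).__name__)  # MarlinLinearMethod
print(type(cfg.get_quant_method(FusedMoE())).__name__)    # MarlinMoEMethod
print(cfg.get_quant_method(object()))                      # None
```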
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         enable_tp: bool = True,
     ):
         super().__init__()
+        self.quant_config = quant_config
 
         self.enable_tp = enable_tp
         if self.enable_tp:
sglang/srt/lora/lora.py CHANGED
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -15,9 +15,11 @@
 
 import logging
 import multiprocessing as mp
+import signal
 import threading
 from enum import Enum, auto
 
+import psutil
 import zmq
 
 from sglang.srt.managers.io_struct import (
@@ -26,13 +28,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    bind_port,
-    configure_logger,
-    get_zmq_socket,
-    kill_parent_process,
-    suppress_other_loggers,
-)
+from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
     pipe_writer,
 ):
     configure_logger(server_args)
-    suppress_other_loggers()
+    parent_process = psutil.Process().parent()
 
     try:
         controller = DataParallelController(server_args, port_args)
@@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
         )
         controller.event_loop()
     except Exception:
-        msg = get_exception_traceback()
-        logger.error(msg)
-        kill_parent_process()
+        traceback = get_exception_traceback()
+        logger.error(f"DataParallelController hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
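Both manager processes (here and in the detokenizer below) now report fatal errors by logging the traceback and sending SIGQUIT to the parent via psutil, instead of calling kill_parent_process() directly. A minimal sketch of that pattern, assuming the parent installs its own SIGQUIT handler (the handler below is illustrative, not sglang's):

```python
# Minimal sketch of child-to-parent failure signaling (POSIX only).
import multiprocessing as mp
import signal

import psutil


def child_worker():
    parent_process = psutil.Process().parent()
    try:
        raise RuntimeError("simulated manager crash")
    except Exception:
        # sglang logs the traceback here, then notifies the parent.
        print("child: hit an exception, signaling parent")
        parent_process.send_signal(signal.SIGQUIT)


def handle_sigquit(signum, frame):
    # A real server would tear down all subprocesses here.
    print("parent: a child reported a fatal error")


if __name__ == "__main__":
    signal.signal(signal.SIGQUIT, handle_sigquit)
    proc = mp.Process(target=child_worker)
    proc.start()
    proc.join()
```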
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -15,9 +15,11 @@
 
 import dataclasses
 import logging
+import signal
 from collections import OrderedDict
 from typing import List, Union
 
+import psutil
 import zmq
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +30,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
+from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback
 
 logger = logging.getLogger(__name__)
@@ -193,11 +195,12 @@ def run_detokenizer_process(
     port_args: PortArgs,
 ):
     configure_logger(server_args)
+    parent_process = psutil.Process().parent()
 
     try:
         manager = DetokenizerManager(server_args, port_args)
         manager.event_loop()
     except Exception:
-        msg = get_exception_traceback()
-        logger.error(msg)
-        kill_parent_process()
+        traceback = get_exception_traceback()
+        logger.error(f"DetokenizerManager hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/image_processor.py CHANGED
@@ -338,7 +338,7 @@ class Qwen2VLImageProcessor(BaseImageProcessor):
             "pixel_values": pixel_values,
             "image_hashes": image_hashes,
             "image_sizes": image_sizes,
-            "modalities": request_obj.modalities,
+            "modalities": request_obj.modalities or ["image"],
             "image_grid_thws": image_grid_thws,
         }
 
sglang/srt/managers/io_struct.py CHANGED
@@ -352,7 +352,7 @@ class FlushCacheReq:
 
 
 @dataclass
-class UpdateWeightReqInput:
+class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
    # The format to load the weights
@@ -360,30 +360,66 @@ class UpdateWeightReqInput:
 
 
 @dataclass
-class UpdateWeightReqOutput:
+class UpdateWeightFromDiskReqOutput:
     success: bool
     message: str
 
 
 @dataclass
-class AbortReq:
-    # The request id
-    rid: str
+class UpdateWeightsFromDistributedReqInput:
+    name: str
+    dtype: str
+    shape: List[int]
 
 
-class ProfileReq(Enum):
-    START_PROFILE = 1
-    STOP_PROFILE = 2
+@dataclass
+class UpdateWeightsFromDistributedReqOutput:
+    success: bool
+    message: str
 
 
 @dataclass
-class GetMemPoolSizeReq:
-    pass
+class InitWeightsUpdateGroupReqInput:
+    # The master address
+    master_address: str
+    # The master port
+    master_port: int
+    # The rank offset
+    rank_offset: int
+    # The world size
+    world_size: int
+    # The group name
+    group_name: str = "weight_update_group"
+    # The backend
+    backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupReqOutput:
+    success: bool
+    message: str
 
 
 @dataclass
-class GetMemPoolSizeReqOutput:
-    size: int
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
+@dataclass
+class AbortReq:
+    # The request id
+    rid: str
+
+
+class ProfileReq(Enum):
+    START_PROFILE = 1
+    STOP_PROFILE = 2
 
 
 @dataclass
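The new IO structs split online weight updating into explicit steps: join a weight-update process group, then announce each tensor to receive over it, with GetWeightsByNameReqInput available to read a (truncated) parameter back for verification. A sketch of constructing these requests client-side, re-declaring only the dataclasses shown above (the address, tensor name, and shape are illustrative):

```python
# Sketch of building the new weight-update requests; how they are sent to the
# server (HTTP endpoint, tokenizer manager, etc.) is outside this snippet.
from dataclasses import dataclass
from typing import List


@dataclass
class InitWeightsUpdateGroupReqInput:
    master_address: str
    master_port: int
    rank_offset: int
    world_size: int
    group_name: str = "weight_update_group"
    backend: str = "nccl"


@dataclass
class UpdateWeightsFromDistributedReqInput:
    name: str
    dtype: str
    shape: List[int]


# 1) Join a process group that also contains the training job's ranks.
init_req = InitWeightsUpdateGroupReqInput(
    master_address="10.0.0.1",  # illustrative address/port
    master_port=29500,
    rank_offset=1,
    world_size=2,
)

# 2) Announce one tensor to receive over that group (the broadcast itself
#    happens in the serving engine, not here).
update_req = UpdateWeightsFromDistributedReqInput(
    name="model.layers.0.self_attn.q_proj.weight",  # illustrative parameter name
    dtype="bfloat16",
    shape=[4096, 4096],
)

print(init_req)
print(update_req)
```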
sglang/srt/managers/schedule_batch.py CHANGED
@@ -124,7 +124,7 @@ class FINISH_ABORT(BaseFinishReason):
 class ImageInputs:
     """The image related inputs."""
 
-    pixel_values: torch.Tensor
+    pixel_values: Union[torch.Tensor, np.array]
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
@@ -132,7 +132,7 @@ class ImageInputs:
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
 
-    image_embeds: Optional[List[torch.Tensor]] = None
+    # Llava related
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
 
@@ -141,19 +141,17 @@ class ImageInputs:
     mrope_position_delta: Optional[torch.Tensor] = None
 
     @staticmethod
-    def from_dict(obj, vocab_size):
-        # Use image hash as fake token_ids, which is then used for prefix matching
+    def from_dict(obj: dict):
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hashes=hash(tuple(obj["image_hashes"])),
+            image_hashes=obj["image_hashes"],
         )
-        image_hash = ret.image_hashes
-        ret.pad_values = [
-            (image_hash) % vocab_size,
-            (image_hash >> 16) % vocab_size,
-            (image_hash >> 32) % vocab_size,
-            (image_hash >> 64) % vocab_size,
-        ]
+
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+        ret.pad_values = [x % (1 << 30) for x in ret.image_hashes]
 
         optional_args = [
             "image_sizes",
@@ -168,17 +166,16 @@ class ImageInputs:
 
         return ret
 
-    def merge(self, other, vocab_size):
+    def merge(self, other):
         assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
         self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-        self.image_hashes += other.image_hashes
 
-        self.pad_values = [
-            (self.image_hashes) % vocab_size,
-            (self.image_hashes >> 16) % vocab_size,
-            (self.image_hashes >> 32) % vocab_size,
-            (self.image_hashes >> 64) % vocab_size,
-        ]
+        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
+        # Please note that if the `input_ids` is later used in the model forward,
+        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
+        # errors in cuda kernels. See also llava.py for example.
+        self.image_hashes += other.image_hashes
+        self.pad_values = [x % (1 << 30) for x in self.image_hashes]
 
 
         optional_args = [
@@ -231,6 +228,7 @@ class Req:
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -290,11 +288,11 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
 
-    def extend_image_inputs(self, image_inputs, vocab_size):
+    def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
             self.image_inputs = image_inputs
         else:
-            self.image_inputs.merge(image_inputs, vocab_size)
+            self.image_inputs.merge(image_inputs)
 
     # whether request reached finished condition
     def finished(self) -> bool:
@@ -368,6 +366,10 @@ class Req:
         if self.finished():
             return
 
+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
@@ -741,20 +743,24 @@ class ScheduleBatch:
         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
        )
-        write_req_to_token_pool_triton[(bs,)](
-            self.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            pre_lens,
-            self.seq_lens,
-            extend_lens,
-            self.out_cache_loc,
-            self.req_to_token_pool.req_to_token.shape[1],
-        )
-        # The triton kernel is equivalent to the following python code.
-        # self.req_to_token_pool.write(
-        #     (req.req_pool_idx, slice(pre_len, seq_len)),
-        #     out_cache_loc[pt : pt + req.extend_input_len],
-        # )
+        if global_server_args_dict["attention_backend"] != "torch_native":
+            write_req_to_token_pool_triton[(bs,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                pre_lens,
+                self.seq_lens,
+                extend_lens,
+                self.out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(bs):
+                self.req_to_token_pool.write(
+                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                )
+                pt += self.extend_lens[i]
         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
 
         if self.model_config.is_encoder_decoder:
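The new comments in ImageInputs stress that the hash-derived pad values can exceed the vocabulary size, so any model that feeds the padded input_ids into an embedding must clamp them first (as llava.py does). A small illustrative sketch of that guard:

```python
# Illustrative only: clamp hash-derived pad values before an embedding lookup.
import torch

vocab_size = 32000
image_hashes = [hash("image-0"), hash("image-1")]  # placeholder hashes

# Fake token ids used as radix-cache keys (same formula as ImageInputs.from_dict).
pad_values = [h % (1 << 30) for h in image_hashes]

input_ids = torch.tensor([1, 15, pad_values[0], pad_values[1], 2])
input_ids.clamp_(min=0, max=vocab_size - 1)  # now safe for nn.Embedding(vocab_size, ...)
print(input_ids)
```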
sglang/srt/managers/schedule_policy.py CHANGED
@@ -142,7 +142,7 @@ class PrefillAdder:
 
         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_being_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
@@ -182,7 +182,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ class PrefillAdder:
        else:
            # Chunked prefill
            trunc_len = self.rem_chunk_tokens
+            if trunc_len == 0:
+                return AddReqResult.OTHER
+
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[:trunc_len]
            self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_being_chunked_req = req
            self._prefill_one_req(0, trunc_len, 0)
 
         return self.budget_state()
@@ -326,7 +329,7 @@ class PrefillAdder:
            req.extend_input_len = trunc_len
            req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
            self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_being_chunked_req = req
            self.tree_cache.inc_lock_ref(req.last_node)
            self._prefill_one_req(prefix_len, trunc_len, 0)