sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py CHANGED
@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
 
             # Compute the logits and logprobs for all required tokens
             states = torch.cat(states, dim=0)
-            all_logits = torch.matmul(states, weight.T)
+            all_logits = self._get_logits(states, lm_head)
             if self.do_tensor_parallel_all_gather:
                 all_logits = tensor_model_parallel_all_gather(all_logits)
             all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
             output_top_logprobs=output_top_logprobs,
         )
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+
 
 def test():
     all_logprobs = torch.tensor(
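
The hunks above change the LogitsProcessor calling convention: callers now pass the lm_head module itself instead of its weight tensor, and the new _get_logits helper falls back to the module's linear_method when no materialized .weight exists (e.g. GGUF checkpoints). A minimal caller sketch, assuming a model wired roughly like sglang's llama.py; the class and attribute names below are illustrative, not taken from this diff:

from torch import nn

class TinyCausalLM(nn.Module):
    """Illustrative wrapper only; mirrors the llama.py convention of holding an lm_head module."""

    def __init__(self, backbone, lm_head, logits_processor):
        super().__init__()
        self.model = backbone                     # transformer backbone
        self.lm_head = lm_head                    # ParallelLMHead / VocabParallelEmbedding
        self.logits_processor = logits_processor  # LogitsProcessor from the hunk above

    def forward(self, input_ids, positions, forward_batch):
        hidden_states = self.model(input_ids, positions, forward_batch)
        # 0.3.6.post3 passed `self.lm_head.weight`; 0.4.0 passes the module itself so
        # _get_logits can dispatch to lm_head.linear_method when .weight is absent.
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )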
sglang/srt/layers/quantization/__init__.py CHANGED
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
 # Apply patches when module is imported
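
These patches replace the get_quant_method of vllm's GPTQMarlinConfig and AWQMarlinConfig so that sglang's FusedMoE layers receive a Marlin MoE method while plain linear layers keep the Marlin linear method. A self-contained toy sketch of the setattr-based patch pattern (the classes below are placeholders, not the real vllm or sglang types):

class ThirdPartyQuantConfig:
    """Stand-in for an upstream config whose method we want to override."""

    def get_quant_method(self, layer, prefix):
        return None  # upstream default


class LinearLayer: ...
class MoELayer: ...


def patched_get_quant_method(self, layer, prefix):
    # Dispatch on the layer type, mirroring gptq_get_quant_method/awq_get_quant_method.
    if isinstance(layer, LinearLayer):
        return "linear-method"
    if isinstance(layer, MoELayer):
        return "moe-method"
    return None


setattr(ThirdPartyQuantConfig, "get_quant_method", patched_get_quant_method)
assert ThirdPartyQuantConfig().get_quant_method(MoELayer(), prefix="") == "moe-method"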
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         enable_tp: bool = True,
     ):
         super().__init__()
+        self.quant_config = quant_config
 
         self.enable_tp = enable_tp
         if self.enable_tp:
sglang/srt/lora/lora.py CHANGED
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/managers/io_struct.py CHANGED
@@ -352,7 +352,7 @@ class FlushCacheReq:
 
 
 @dataclass
-class UpdateWeightReqInput:
+class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -360,11 +360,57 @@ UpdateWeightReqInput
 
 
 @dataclass
-class UpdateWeightReqOutput:
+class UpdateWeightFromDiskReqOutput:
     success: bool
     message: str
 
 
+@dataclass
+class UpdateWeightsFromDistributedReqInput:
+    name: str
+    dtype: str
+    shape: List[int]
+
+
+@dataclass
+class UpdateWeightsFromDistributedReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class InitWeightsUpdateGroupReqInput:
+    # The master address
+    master_address: str
+    # The master port
+    master_port: int
+    # The rank offset
+    rank_offset: int
+    # The world size
+    world_size: int
+    # The group name
+    group_name: str = "weight_update_group"
+    # The backend
+    backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
 @dataclass
 class AbortReq:
     # The request id
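
The renamed and newly added dataclasses above define the payloads for the weight-management requests handled by the scheduler later in this diff. A short construction sketch; the field names and defaults come from the hunk above, while the concrete values are illustrative examples only:

from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)

# Join a process group for online weight updates (group_name/backend keep their defaults).
init_req = InitWeightsUpdateGroupReqInput(
    master_address="127.0.0.1",
    master_port=29500,
    rank_offset=1,
    world_size=2,
)

# Announce one named tensor to be received over that group.
update_req = UpdateWeightsFromDistributedReqInput(
    name="model.embed_tokens.weight",
    dtype="bfloat16",
    shape=[32000, 4096],
)

# Fetch a truncated view of a parameter for verification.
get_req = GetWeightsByNameReqInput(name="lm_head.weight", truncate_size=5)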
sglang/srt/managers/schedule_batch.py CHANGED
@@ -743,20 +743,24 @@ class ScheduleBatch:
         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        write_req_to_token_pool_triton[(bs,)](
-            self.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            pre_lens,
-            self.seq_lens,
-            extend_lens,
-            self.out_cache_loc,
-            self.req_to_token_pool.req_to_token.shape[1],
-        )
-        # The triton kernel is equivalent to the following python code.
-        # self.req_to_token_pool.write(
-        #     (req.req_pool_idx, slice(pre_len, seq_len)),
-        #     out_cache_loc[pt : pt + req.extend_input_len],
-        # )
+        if global_server_args_dict["attention_backend"] != "torch_native":
+            write_req_to_token_pool_triton[(bs,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                pre_lens,
+                self.seq_lens,
+                extend_lens,
+                self.out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(bs):
+                self.req_to_token_pool.write(
+                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                )
+                pt += self.extend_lens[i]
         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
 
         if self.model_config.is_encoder_decoder:
sglang/srt/managers/schedule_policy.py CHANGED
@@ -142,7 +142,7 @@ class PrefillAdder:
 
         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_being_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
@@ -182,7 +182,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ class PrefillAdder:
         else:
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
+            if trunc_len == 0:
+                return AddReqResult.OTHER
+
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_being_chunked_req = req
             self._prefill_one_req(0, trunc_len, 0)
 
         return self.budget_state()
@@ -326,7 +329,7 @@ class PrefillAdder:
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_being_chunked_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
 
sglang/srt/managers/scheduler.py CHANGED
@@ -38,13 +38,19 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UpdateWeightReqInput,
-    UpdateWeightReqOutput,
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -141,9 +147,12 @@ class Scheduler:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation
 
@@ -253,6 +262,8 @@ class Scheduler:
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
+        if self.chunked_prefill_size <= 0:  # -1 means disable
+            self.chunked_prefill_size = None
         self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
@@ -504,11 +515,27 @@ class Scheduler:
             self.flush_cache()
         elif isinstance(recv_req, AbortReq):
             self.abort_request(recv_req)
-        elif isinstance(recv_req, UpdateWeightReqInput):
-            success, message = self.update_weights(recv_req)
+        elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
+            success, message = self.update_weights_from_disk(recv_req)
             self.send_to_tokenizer.send_pyobj(
-                UpdateWeightReqOutput(success, message)
+                UpdateWeightFromDiskReqOutput(success, message)
             )
+        elif isinstance(recv_req, GetWeightsByNameReqInput):
+            parameter = self.get_weights_by_name(recv_req)
+            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+        elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
+            success, message = self.init_weights_update_group(recv_req)
+            self.send_to_tokenizer.send_pyobj(
+                InitWeightsUpdateGroupReqOutput(success, message)
+            )
+        elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
+            success, message = self.update_weights_from_distributed(recv_req)
+            self.send_to_tokenizer.send_pyobj(
+                UpdateWeightsFromDistributedReqOutput(success, message)
+            )
+        elif isinstance(recv_req, GetWeightsByNameReqInput):
+            parameter = self.get_weights_by_name(recv_req)
+            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
         elif isinstance(recv_req, ProfileReq):
             if recv_req == ProfileReq.START_PROFILE:
                 self.start_profile()
@@ -653,7 +680,7 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def log_prefill_stats(self, adder, can_run_list, running_bs, has_inflight):
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
         if isinstance(self.tree_cache, RadixCache):
             self.tree_cache_metrics["total"] += (
                 adder.log_input_tokens + adder.log_hit_tokens
@@ -677,14 +704,14 @@ class Scheduler:
             f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
         )
 
         if self.enable_metrics:
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
-            self.stats.num_queue_reqs = len(self.waiting_queue) + has_inflight
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
             self.stats.cache_hit_rate = tree_cache_hit_rate
             self.metrics_collector.log_stats(self.stats)
 
@@ -745,7 +772,7 @@ class Scheduler:
             # Move the chunked request out of the batch
             self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
             self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-            # Inflight request keeps its rid but will get a new req_pool_idx
+            # being chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
             self.batch_is_full = False
 
@@ -796,10 +823,10 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )
 
-        has_inflight = self.being_chunked_req is not None
-        if has_inflight:
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.add_inflight_req(self.being_chunked_req)
+            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
 
         if self.lora_paths:
             lora_set = (
@@ -841,16 +868,16 @@ class Scheduler:
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]
 
-        if adder.new_inflight_req is not None:
+        if adder.new_being_chunked_req is not None:
             assert self.being_chunked_req is None
-            self.being_chunked_req = adder.new_inflight_req
+            self.being_chunked_req = adder.new_being_chunked_req
 
         if self.being_chunked_req:
             self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
         if self.tp_rank == 0:
-            self.log_prefill_stats(adder, can_run_list, running_bs, has_inflight)
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1023,7 +1050,7 @@ class Scheduler:
                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
             else:
-                # Inflight reqs' prefill is not finished
+                # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
@@ -1051,7 +1078,7 @@ class Scheduler:
                 else:
                     self.tree_cache.cache_unfinished_req(req)
             else:
-                # Inflight reqs' prefill is not finished
+                # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
@@ -1146,6 +1173,14 @@ class Scheduler:
                 + 1 : len(req.fill_ids)
                 - req.last_update_decode_tokens
             ]
+
+            # Clip the padded hash values from image tokens.
+            # Otherwise, it will lead to detokenization errors.
+            input_token_ids = [
+                x if x < self.model_config.vocab_size - 1 else 0
+                for x in input_token_ids
+            ]
+
             req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
 
             if (
@@ -1361,9 +1396,26 @@ class Scheduler:
                 req.to_abort = True
                 break
 
-    def update_weights(self, recv_req: UpdateWeightReqInput):
-        """In-place update of the weights."""
-        success, message = self.tp_worker.update_weights(recv_req)
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        """In-place update of the weights from disk."""
+        success, message = self.tp_worker.update_weights_from_disk(recv_req)
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        """Initialize the online model parameter update group."""
+        success, message = self.tp_worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        """Update the online model parameter."""
+        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
@@ -1371,6 +1423,10 @@ class Scheduler:
             logger.error(message)
         return success, message
 
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return parameter
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")