sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff covers publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +18 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +76 -20
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +267 -170
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +245 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 
@@ -163,7 +164,7 @@ class LogitsProcessor(nn.Module):
         self,
         input_ids,
         hidden_states,
-        …
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
-        last_logits = …
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ class LogitsProcessor(nn.Module):
 
         # Compute the logits and logprobs for all required tokens
         states = torch.cat(states, dim=0)
-        all_logits = …
+        all_logits = self._get_logits(states, lm_head)
         if self.do_tensor_parallel_all_gather:
             all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ class LogitsProcessor(nn.Module):
             output_top_logprobs=output_top_logprobs,
         )
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+
 
 def test():
     all_logprobs = torch.tensor(
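With the new signature, callers hand the lm_head module itself to the logits processor instead of a raw weight tensor, and _get_logits decides between a plain matmul and the GGUF linear_method path. The sketch below is illustrative only: the model-side changes (e.g. in sglang/srt/models/llama.py, listed above but not shown in this excerpt) are assumed, and DummyCausalLM / DummyTransformer are hypothetical names.

class DummyCausalLM(nn.Module):
    """Hypothetical model wiring for the new LogitsProcessor signature."""

    def __init__(self, config):
        super().__init__()
        self.model = DummyTransformer(config)  # hypothetical backbone, not an sglang class
        self.lm_head = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    def forward(self, input_ids, positions, forward_batch):
        hidden_states = self.model(input_ids, positions, forward_batch)
        # Pass the lm_head module; _get_logits picks the matmul or GGUF path.
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )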
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -117,10 +117,44 @@ def fp8_get_quant_method(self, layer, prefix):
     return None
 
 
+def gptq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return GPTQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)
+    return None
+
+
+def awq_get_quant_method(self, layer, prefix):
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinLinearMethod,
+        AWQMoEMethod,
+    )
+
+    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+
+    if isinstance(layer, LinearBase):
+        return AWQMarlinLinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return AWQMoEMethod(self)
+    return None
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
     setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
     setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
 # Apply patches when module is imported
sglang/srt/lora/lora.py
CHANGED
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/managers/io_struct.py
CHANGED
@@ -352,7 +352,7 @@ class FlushCacheReq:
 
 
 @dataclass
-class UpdateWeightReqInput:
+class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -360,11 +360,57 @@ class UpdateWeightReqInput:
 
 
 @dataclass
-class …
+class UpdateWeightFromDiskReqOutput:
     success: bool
     message: str
 
 
+@dataclass
+class UpdateWeightsFromDistributedReqInput:
+    name: str
+    dtype: str
+    shape: List[int]
+
+
+@dataclass
+class UpdateWeightsFromDistributedReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class InitWeightsUpdateGroupReqInput:
+    # The master address
+    master_address: str
+    # The master port
+    master_port: int
+    # The rank offset
+    rank_offset: int
+    # The world size
+    world_size: int
+    # The group name
+    group_name: str = "weight_update_group"
+    # The backend
+    backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
 @dataclass
 class AbortReq:
     # The request id
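These dataclasses are plain request/response messages. As an illustrative sketch (not code from this release; all values and the surrounding transport are assumptions), a driver process such as an RLHF trainer might populate them like this before handing them to the serving engine:

# Illustrative values only; the server endpoint / tokenizer manager that carries
# these messages is outside this file.
init_req = InitWeightsUpdateGroupReqInput(
    master_address="10.0.0.1",   # hypothetical trainer host
    master_port=29500,
    rank_offset=1,               # engine ranks placed after the trainer's rank 0
    world_size=1 + 8,            # trainer plus eight tensor-parallel workers (assumed)
    group_name="weight_update_group",
    backend="nccl",
)
update_req = UpdateWeightsFromDistributedReqInput(
    name="model.layers.0.self_attn.qkv_proj.weight",  # example parameter name
    dtype="bfloat16",
    shape=[3072, 4096],          # made-up shape for illustration
)
probe_req = GetWeightsByNameReqInput(name="lm_head.weight", truncate_size=16)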
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -743,20 +743,24 @@ class ScheduleBatch:
         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        … (14 removed lines; their content was not captured in the source rendering)
+        if global_server_args_dict["attention_backend"] != "torch_native":
+            write_req_to_token_pool_triton[(bs,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                pre_lens,
+                self.seq_lens,
+                extend_lens,
+                self.out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(bs):
+                self.req_to_token_pool.write(
+                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                )
+                pt += self.extend_lens[i]
         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
 
         if self.model_config.is_encoder_decoder:
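For reference, the torch_native fallback branch is just a per-request slice assignment into the request-to-token table. A standalone toy version of the same bookkeeping (made-up sizes, no memory-pool object) looks like:

import torch

# Toy illustration: write each request's newly allocated KV-cache slot indices
# into its row of a req-to-token table, between pre_len and seq_len.
req_to_token = torch.zeros(4, 16, dtype=torch.int64)       # [max_reqs, max_context_len]
req_pool_indices = [2, 0]                                   # table rows owned by this batch
pre_lens, seq_lens = [3, 0], [7, 5]                         # per-request old / new lengths
extend_lens = [s - p for s, p in zip(seq_lens, pre_lens)]
out_cache_loc = torch.arange(100, 100 + sum(extend_lens))   # newly allocated slots

pt = 0
for i, row in enumerate(req_pool_indices):
    req_to_token[row, pre_lens[i] : seq_lens[i]] = out_cache_loc[pt : pt + extend_lens[i]]
    pt += extend_lens[i]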
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -142,7 +142,7 @@ class PrefillAdder:
 
         self.req_states = None
         self.can_run_list = []
-        self.…
+        self.new_being_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
@@ -182,7 +182,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def …
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ class PrefillAdder:
         else:
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
+            if trunc_len == 0:
+                return AddReqResult.OTHER
+
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.…
+            self.new_being_chunked_req = req
             self._prefill_one_req(0, trunc_len, 0)
 
         return self.budget_state()
@@ -326,7 +329,7 @@ class PrefillAdder:
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
             self.can_run_list.append(req)
-            self.…
+            self.new_being_chunked_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
 
sglang/srt/managers/scheduler.py
CHANGED
@@ -38,13 +38,19 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
+    GetWeightsByNameReqInput,
+    GetWeightsByNameReqOutput,
+    InitWeightsUpdateGroupReqInput,
+    InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    …
-    …
+    UpdateWeightFromDiskReqInput,
+    UpdateWeightFromDiskReqOutput,
+    UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromDistributedReqOutput,
 )
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
@@ -141,9 +147,12 @@ class Scheduler:
         self.model_config = ModelConfig(
             server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
         )
         self.is_generation = self.model_config.is_generation
 
@@ -253,6 +262,8 @@ class Scheduler:
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
+        if self.chunked_prefill_size <= 0:  # -1 means disable
+            self.chunked_prefill_size = None
         self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
@@ -504,11 +515,27 @@ class Scheduler:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
-            elif isinstance(recv_req, …
-                success, message = self.…
+            elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
+                success, message = self.update_weights_from_disk(recv_req)
                 self.send_to_tokenizer.send_pyobj(
-                    …
+                    UpdateWeightFromDiskReqOutput(success, message)
                 )
+            elif isinstance(recv_req, GetWeightsByNameReqInput):
+                parameter = self.get_weights_by_name(recv_req)
+                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
+            elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
+                success, message = self.init_weights_update_group(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    InitWeightsUpdateGroupReqOutput(success, message)
+                )
+            elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
+                success, message = self.update_weights_from_distributed(recv_req)
+                self.send_to_tokenizer.send_pyobj(
+                    UpdateWeightsFromDistributedReqOutput(success, message)
+                )
+            elif isinstance(recv_req, GetWeightsByNameReqInput):
+                parameter = self.get_weights_by_name(recv_req)
+                self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
             elif isinstance(recv_req, ProfileReq):
                 if recv_req == ProfileReq.START_PROFILE:
                     self.start_profile()
@@ -653,7 +680,7 @@ class Scheduler:
 
         self.waiting_queue.append(req)
 
-    def log_prefill_stats(self, adder, can_run_list, running_bs, …
+    def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
         if isinstance(self.tree_cache, RadixCache):
             self.tree_cache_metrics["total"] += (
                 adder.log_input_tokens + adder.log_hit_tokens
@@ -677,14 +704,14 @@ class Scheduler:
             f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) + …
+            f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
         )
 
         if self.enable_metrics:
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
             self.stats.token_usage = round(num_used / self.max_total_num_tokens, 2)
-            self.stats.num_queue_reqs = len(self.waiting_queue) + …
+            self.stats.num_queue_reqs = len(self.waiting_queue) + has_being_chunked
             self.stats.cache_hit_rate = tree_cache_hit_rate
             self.metrics_collector.log_stats(self.stats)
 
@@ -745,7 +772,7 @@ class Scheduler:
             # Move the chunked request out of the batch
             self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
             self.tree_cache.cache_unfinished_req(self.being_chunked_req)
-            # …
+            # being chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
             self.batch_is_full = False
 
@@ -796,10 +823,10 @@ class Scheduler:
             running_bs if self.is_mixed_chunk else 0,
         )
 
-        …
-        if …
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
             self.being_chunked_req.init_next_round_input()
-            self.being_chunked_req = adder.…
+            self.being_chunked_req = adder.add_being_chunked_req(self.being_chunked_req)
 
         if self.lora_paths:
             lora_set = (
@@ -841,16 +868,16 @@ class Scheduler:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]
 
-        if adder.…
+        if adder.new_being_chunked_req is not None:
             assert self.being_chunked_req is None
-            self.being_chunked_req = adder.…
+            self.being_chunked_req = adder.new_being_chunked_req
 
         if self.being_chunked_req:
            self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
         if self.tp_rank == 0:
-            self.log_prefill_stats(adder, can_run_list, running_bs, …
+            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
         # Create a new batch
         new_batch = ScheduleBatch.init_new(
@@ -1023,7 +1050,7 @@ class Scheduler:
                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)
             else:
-                # …
+                # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         if batch.next_batch_sampling_info:
@@ -1051,7 +1078,7 @@ class Scheduler:
                 else:
                     self.tree_cache.cache_unfinished_req(req)
             else:
-                # …
+                # being chunked reqs' prefill is not finished
                 req.is_being_chunked -= 1
 
         self.stream_output(batch.reqs)
@@ -1146,6 +1173,14 @@ class Scheduler:
                     + 1 : len(req.fill_ids)
                     - req.last_update_decode_tokens
                 ]
+
+                # Clip the padded hash values from image tokens.
+                # Otherwise, it will lead to detokenization errors.
+                input_token_ids = [
+                    x if x < self.model_config.vocab_size - 1 else 0
+                    for x in input_token_ids
+                ]
+
                 req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
 
                 if (
@@ -1361,9 +1396,26 @@ class Scheduler:
                 req.to_abort = True
                 break
 
-    def …
-        """In-place update of the weights."""
-        success, message = self.tp_worker.…
+    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
+        """In-place update of the weights from disk."""
+        success, message = self.tp_worker.update_weights_from_disk(recv_req)
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
+        return success, message
+
+    def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
+        """Initialize the online model parameter update group."""
+        success, message = self.tp_worker.init_weights_update_group(recv_req)
+        return success, message
+
+    def update_weights_from_distributed(
+        self, recv_req: UpdateWeightsFromDistributedReqInput
+    ):
+        """Update the online model parameter."""
+        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
@@ -1371,6 +1423,10 @@ class Scheduler:
             logger.error(message)
         return success, message
 
+    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
+        parameter = self.tp_worker.get_weights_by_name(recv_req)
+        return parameter
+
     def start_profile(self) -> None:
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")
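The scheduler methods above only forward these requests to the TP worker; the worker-side implementation is not part of this excerpt. Conceptually, update_weights_from_distributed implies receiving one broadcast tensor per request over the group created by init_weights_update_group. The sketch below uses plain torch.distributed for illustration; the function name and the actual group setup in sglang are assumptions, not APIs shown in this diff.

import torch
import torch.distributed as dist

def receive_named_weight(name: str, dtype: str, shape: list, group) -> torch.Tensor:
    # Receiver-side sketch: allocate a buffer matching the announced dtype/shape
    # (the fields of UpdateWeightsFromDistributedReqInput), then receive the
    # broadcast from rank 0 of the weight-update group. The returned tensor would
    # be copied into the parameter called `name` by the model runner.
    buf = torch.empty(shape, dtype=getattr(torch, dtype), device="cuda")
    dist.broadcast(buf, src=0, group=group)
    return buf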