sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in a supported public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +46 -25
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +184 -63
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -248
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/model_executor/cuda_graph_runner.py +15 -19
- sglang/srt/model_executor/forward_batch_info.py +94 -95
- sglang/srt/model_executor/model_runner.py +76 -75
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +14 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +71 -26
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +18 -9
- sglang/version.py +1 -1
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -474
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.2.dist-info/RECORD +0 -135
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,60 +17,22 @@ limitations under the License.
 
 import json
 import logging
-import multiprocessing
-import os
-import pickle
-import time
-import warnings
-from typing import Any, List, Optional
 
-import torch
-import torch.distributed
-import torch.distributed as dist
-
-from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.
-from sglang.srt.managers.
-
-    BatchEmbeddingOut,
-    BatchTokenIDOut,
-    FlushCacheReq,
-    TokenizedEmbeddingReqInput,
-    TokenizedGenerateReqInput,
-    UpdateWeightReqInput,
-    UpdateWeightReqOutput,
-)
-from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
-from sglang.srt.managers.schedule_batch import (
-    FINISH_ABORT,
-    BaseFinishReason,
-    Req,
-    ScheduleBatch,
-)
-from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.radix_cache import RadixCache
+from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
-    configure_logger,
-    is_multimodal_model,
-    set_random_seed,
-    suppress_other_loggers,
-)
-from sglang.utils import get_exception_traceback
+from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
 
 logger = logging.getLogger(__name__)
 
 
-
-
-
+class TpModelWorker:
+    """A tensor parallel model worker."""
 
-class ModelTpServer:
     def __init__(
         self,
         gpu_id: int,
@@ -78,17 +40,8 @@ class ModelTpServer:
         server_args: ServerArgs,
         nccl_port: int,
     ):
-
-
-        # Parse arguments
-        self.gpu_id = gpu_id
+        # Parse args
         self.tp_rank = tp_rank
-        self.tp_size = server_args.tp_size
-        self.dp_size = server_args.dp_size
-        self.schedule_policy = server_args.schedule_policy
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        self.lora_paths = server_args.lora_paths
-        self.max_loras_per_batch = server_args.max_loras_per_batch
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -122,6 +75,8 @@ class ModelTpServer:
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
+
+        # Profile number of tokens
        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
        self.max_prefill_tokens = server_args.max_prefill_tokens
        self.max_running_requests = min(
@@ -138,888 +93,36 @@ class ModelTpServer:
         )
 
         # Sync random seed across TP workers
-
+        self.random_seed = broadcast_pyobj(
             [server_args.random_seed],
             self.tp_rank,
             self.model_runner.tp_group.cpu_group,
         )[0]
-        set_random_seed(
-
-        # Print debug info
-        logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
-            f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"max_running_requests={self.max_running_requests}, "
-            f"context_len={self.model_config.context_len}"
-        )
[... removed: the old in-worker state setup (RadixCache/ChunkCache tree cache, PolicyScheduler, req/token memory-pool bookkeeping, waiting queue and running batch, chunked-prefill and FSM-cache initialization, new-token-ratio estimation), exposed_step, and the start of forward_step ...]
+        set_random_seed(self.random_seed)
 
[... removed: the rest of forward_step, print_decode_stats, check_memory, handle_generate_request, handle_embedding_request, and the start of get_new_prefill_batch (PrefillAdder setup) ...]
+    def get_token_and_memory_info(self):
+        return (
+            self.max_total_num_tokens,
             self.max_prefill_tokens,
[... removed: the rest of get_new_prefill_batch, forward_prefill_batch, and the start of add_logprob_return_values ...]
+            self.max_running_requests,
+            self.max_req_input_len,
+            self.random_seed,
         )
 
+    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output = self.model_runner.forward(forward_batch)
+        next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
+        return logits_output, next_token_ids
 
[... removed: the middle of the old add_logprob_return_values body ...]
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings.tolist()
+        return embeddings
 
[... removed: the rest of add_logprob_return_values, forward_decode_batch, handle_finished_requests, flush_cache, and abort_request ...]
-    def update_weights(self, recv_req):
+    def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
         )
-        if success:
-            flash_cache_success = self.flush_cache()
-            assert flash_cache_success, "Cache flush failed after updating weights"
-        else:
-            logger.error(message)
         return success, message
[... removed: the module-level run_tp_server, launch_tp_servers, and broadcast_recv_input helpers ...]