sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
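The renames above move public modules, so downstream imports change with them. A minimal sketch of the adjustment a script importing ModelConfig directly would need (the new path is taken from the configs/ move listed above and from the import hunk below; any other usage is hypothetical):

    # sglang 0.3.0 (hypothetical downstream import)
    from sglang.srt.model_config import ModelConfig
    # sglang 0.3.1.post1
    from sglang.srt.configs.model_config import ModelConfig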
sglang/srt/managers/tp_worker.py
CHANGED
@@ -15,19 +15,21 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
+import json
 import logging
 import multiprocessing
 import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional
+from typing import Any, List, Optional
 
 import torch
 import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -51,8 +53,6 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -66,6 +66,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+# Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
 
@@ -76,26 +77,26 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_override_args: dict,
     ):
         suppress_other_loggers()
 
-        #
+        # Parse arguments
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+        self.lora_paths = server_args.lora_paths
+        self.max_loras_per_batch = server_args.max_loras_per_batch
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
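The worker no longer receives a model_override_args dict; overrides now travel as a JSON string on ServerArgs and are decoded with json.loads, and the new LoRA settings ride on ServerArgs as well. A minimal sketch of supplying them, assuming ServerArgs is a dataclass whose fields match the attributes referenced in this hunk (the model path, override value, and adapter path are hypothetical):

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3-8B",  # hypothetical model
        json_model_override_args='{"max_position_embeddings": 8192}',  # hypothetical override
        lora_paths=["/path/to/adapter_a"],  # hypothetical adapter path
        max_loras_per_batch=4,
    )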
@@ -129,14 +130,14 @@ class ModelTpServer:
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
             ),
-            self.model_runner.req_to_token_pool.size
+            self.model_runner.req_to_token_pool.size,
         )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
 
-        # Sync random seed
+        # Sync random seed across TP workers
         server_args.random_seed = broadcast_recv_input(
             [server_args.random_seed],
             self.tp_rank,
@@ -144,7 +145,7 @@ class ModelTpServer:
         )[0]
         set_random_seed(server_args.random_seed)
 
-        # Print info
+        # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +182,7 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
-        #
+        # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
         self.is_mixed_chunk = (
@@ -197,16 +198,7 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
-
-        )
-        self.json_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-            skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=True,
+            constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -221,15 +213,18 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.do_not_get_new_batch = False
 
     def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(
-                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-                ):
+                if isinstance(recv_req, TokenizedGenerateReqInput):
                     self.handle_generate_request(recv_req)
+                    self.do_not_get_new_batch = False
+                elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+                    self.handle_embedding_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
@@ -253,7 +248,11 @@ class ModelTpServer:
 
     @torch.inference_mode()
     def forward_step(self):
-        new_batch = self.get_new_prefill_batch()
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
+        self.do_not_get_new_batch = False
 
         if new_batch is not None:
             # Run a new prefill batch
@@ -280,7 +279,7 @@ class ModelTpServer:
                         self.running_batch = None
                         break
 
-                    if self.out_pyobjs and self.running_batch.has_stream
+                    if self.out_pyobjs and self.running_batch.has_stream:
                         break
             else:
                 self.check_memory()
@@ -325,73 +324,102 @@ class ModelTpServer:
 
     def handle_generate_request(
         self,
-        recv_req:
+        recv_req: TokenizedGenerateReqInput,
     ):
-        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        if isinstance(recv_req, TokenizedGenerateReqInput):
+            req = Req(
+                recv_req.rid,
+                recv_req.input_text,
+                recv_req.input_ids,
+                lora_path=recv_req.lora_path,
+            )
+        else:
+            req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        ... (old lines 333-360 are not preserved in this rendering)
+        req.pixel_values = recv_req.pixel_values
+        if req.pixel_values is not None:
+            # Use image hash as fake token_ids, which is then used
+            # for prefix matching
+            image_hash = hash(tuple(recv_req.image_hashes))
+            req.pad_value = [
+                (image_hash) % self.model_config.vocab_size,
+                (image_hash >> 16) % self.model_config.vocab_size,
+                (image_hash >> 32) % self.model_config.vocab_size,
+                (image_hash >> 64) % self.model_config.vocab_size,
+            ]
+            req.image_sizes = recv_req.image_sizes
+            (
+                req.origin_input_ids,
+                req.image_offsets,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values,
+                req.image_sizes,
+            )
+            # Only when pixel values is not None we have modalities
+            req.modalities = recv_req.modalites
+        req.return_logprob = recv_req.return_logprob
+        req.top_logprobs_num = recv_req.top_logprobs_num
+        req.stream = recv_req.stream
+        req.logprob_start_len = recv_req.logprob_start_len
+
+        if req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(recv_req.input_ids) - 1
+
+        # Init regex FSM
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
             if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
-                    req.sampling_params.json_schema
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("json", req.sampling_params.json_schema)
                 )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
-                    )
-
-        # Init regex fsm
             elif req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(
-            ... (old lines 373-376 are not preserved in this rendering)
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("regex", req.sampling_params.regex)
+                )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
 
-        ... (old lines 386-393 are not preserved in this rendering)
+        self.waiting_queue.append(req)
+
+    def handle_embedding_request(
+        self,
+        recv_req: TokenizedEmbeddingReqInput,
+    ):
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+
+        # Truncate prompts that are too long
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
             )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
 
         self.waiting_queue.append(req)
 
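With the separate json_fsm_cache removed in the constructor hunk above, a single regex_fsm_cache now serves both constrained-decoding modes, keyed by a (mode, spec) tuple. A sketch of the two query shapes used in this hunk (the schema and pattern values are hypothetical):

    # JSON-schema constrained decoding
    regex_fsm, computed_regex = self.regex_fsm_cache.query(("json", '{"type": "object"}'))
    # Raw-regex constrained decoding
    regex_fsm, computed_regex = self.regex_fsm_cache.query(("regex", r"(yes|no)"))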
@@ -409,6 +437,8 @@ class ModelTpServer:
 
         adder = PrefillAdder(
             self.tree_cache,
+            self.running_batch,
+            self.new_token_ratio,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
@@ -416,7 +446,7 @@ class ModelTpServer:
         )
 
         if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch
+            adder.remove_running_tokens(self.running_batch)
 
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
@@ -427,12 +457,30 @@ class ModelTpServer:
                 self.current_inflight_req
             )
 
+        if self.lora_paths is not None:
+            lora_set = (
+                set([req.lora_path for req in self.running_batch.reqs])
+                if self.running_batch is not None
+                else set([])
+            )
+
         for req in self.waiting_queue:
+            if adder.no_remaining_tokens():
+                break
             req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            if (
+                self.lora_paths is not None
+                and len(
+                    lora_set
+                    | set([req.lora_path for req in adder.can_run_list])
+                    | set([req.lora_path])
+                )
+                > self.max_loras_per_batch
+            ):
+                break
             res = adder.add_one_req(req)
             if (
                 not res
-                or adder.no_remaining_tokens()
                 or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
                 break
@@ -504,10 +552,9 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                ... (old lines 507-509 are not preserved in this rendering)
-                next_token_ids = batch.check_sample_results(sample_output)
+                logits_output = self.model_runner.forward(batch)
+                next_token_ids = self.model_runner.sample(logits_output, batch)
+
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
@@ -541,7 +588,7 @@ class ModelTpServer:
                 next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
             # Check finish conditions
-            pt = 0
+            logprob_pt = 0
             for i, req in enumerate(batch.reqs):
                 if req is not self.current_inflight_req:
                     # Inflight reqs' prefill is not finished
@@ -565,13 +612,12 @@ class ModelTpServer:
                         self.req_to_token_pool.free(req.req_pool_idx)
 
                 if req.return_logprob:
-                    self.add_logprob_return_values(
-                        i, req, pt, next_token_ids, logits_output
+                    logprob_pt += self.add_logprob_return_values(
+                        i, req, logprob_pt, next_token_ids, logits_output
                     )
-                    pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            logits_output = self.model_runner.forward(batch
+            logits_output = self.model_runner.forward(batch)
             embeddings = logits_output.embeddings.tolist()
 
             # Check finish conditions
@@ -596,48 +642,63 @@ class ModelTpServer:
 
     def add_logprob_return_values(
         self,
-        i,
+        i: int,
         req: Req,
         pt: int,
         next_token_ids: List[int],
         output: LogitsProcessorOutput,
     ):
+        """Attach logprobs to the return values."""
+        req.output_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )
+
+        # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+        num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
+
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
         if req.input_token_logprobs is None:
-            ... (old lines 609-616 are not preserved in this rendering)
+            input_token_logprobs = output.input_token_logprobs[
+                pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+            ]
+            input_token_ids = req.fill_ids[
+                len(req.fill_ids)
+                - num_input_logprobs
+                + 1 : len(req.fill_ids)
+                - req.last_update_decode_tokens
+            ]
+            req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
+
+            if (
+                req.logprob_start_len == 0
+            ):  # The first token does not have logprob, pad it.
                 req.input_token_logprobs = [
                     (None, req.fill_ids[0])
                 ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
+            # Some decode tokens are re-computed in an extend batch
             req.output_token_logprobs.extend(
                 list(
                     zip(
                         output.input_token_logprobs[
                             pt
-                            +
+                            + num_input_logprobs
+                            - 1
                             - req.last_update_decode_tokens : pt
-                            +
+                            + num_input_logprobs
                             - 1
                         ],
-                        req.fill_ids[
+                        req.fill_ids[
+                            len(req.fill_ids)
+                            - req.last_update_decode_tokens : len(req.fill_ids)
+                        ],
                     )
                 )
             )
 
-        req.output_token_logprobs.append(
-            (output.next_token_logprobs[i], next_token_ids[i])
-        )
-
         if req.top_logprobs_num > 0:
             if req.input_top_logprobs is None:
                 req.input_top_logprobs = output.input_top_logprobs[i]
@@ -646,10 +707,12 @@ class ModelTpServer:
 
         if req.last_update_decode_tokens != 0:
             req.output_top_logprobs.extend(
-                output.input_top_logprobs[i][-req.last_update_decode_tokens
+                output.input_top_logprobs[i][-req.last_update_decode_tokens :]
             )
             req.output_top_logprobs.append(output.output_top_logprobs[i])
 
+        return num_input_logprobs
+
     def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
@@ -682,10 +745,8 @@ class ModelTpServer:
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-        ... (old lines 685-686 are not preserved in this rendering)
-        )
-        next_token_ids = batch.check_sample_results(sample_output)
+        logits_output = self.model_runner.forward(batch)
+        next_token_ids = self.model_runner.sample(logits_output, batch)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
@@ -700,6 +761,7 @@ class ModelTpServer:
         next_token_ids = next_token_ids.tolist()
 
         # Check finish condition
+        has_finished = False
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids.append(next_token_id)
@@ -712,6 +774,7 @@ class ModelTpServer:
 
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
+                has_finished = True
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -720,6 +783,9 @@ class ModelTpServer:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+        if not has_finished:
+            self.do_not_get_new_batch = True
+
         self.handle_finished_requests(batch)
 
     def handle_finished_requests(self, batch: ScheduleBatch):
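Together with the exposed_step and forward_step hunks earlier, the do_not_get_new_batch flag skips the prefill-scheduling pass when a decode step finished no request and no new request has arrived, since the scheduler's answer could not have changed. A condensed sketch of the control flow, using the names from this diff:

    # In forward_decode_batch: nothing finished, so batch composition is unchanged.
    if not has_finished:
        self.do_not_get_new_batch = True

    # In forward_step: only look for a new prefill batch when something changed.
    if self.do_not_get_new_batch and self.current_inflight_req is None:
        new_batch = None
    else:
        new_batch = self.get_new_prefill_batch()
    self.do_not_get_new_batch = False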
@@ -742,12 +808,10 @@ class ModelTpServer:
                 unfinished_indices.append(i)
 
             if req.finished() or (
-                ... (old lines 745-748 are not preserved in this rendering)
-                or len(req.output_ids) == 1
-            )
+                req.stream
+                and (
+                    self.decode_forward_ct % self.stream_interval == 0
+                    or len(req.output_ids) == 1
                 )
             ):
                 output_rids.append(req.rid)
@@ -769,7 +833,11 @@ class ModelTpServer:
                     "prompt_tokens": len(req.origin_input_ids),
                     "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason":
+                    "finish_reason": (
+                        req.finished_reason.to_json()
+                        if req.finished_reason is not None
+                        else None
+                    ),
                 }
                 if req.return_logprob:
                     (
@@ -868,6 +936,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
 
 
@@ -876,7 +946,6 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -887,7 +956,6 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -904,14 +972,13 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+            args=(gpu_ids[i], i, server_args, nccl_port),
         )
         proc.start()
         procs.append(proc)