sglang 0.4.1__py3-none-any.whl → 0.4.1.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/bench_serving.py +11 -3
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -30
- sglang/srt/layers/moe/topk.py +14 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +91 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/io_struct.py +29 -8
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +71 -34
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +95 -55
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -6
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/llama.py +13 -2
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +79 -2
- sglang/srt/openai_api/protocol.py +50 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +45 -39
- sglang/srt/server_args.py +17 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/METADATA +5 -5
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/RECORD +41 -39
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.dist-info → sglang-0.4.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -22,7 +22,7 @@ import warnings
|
|
22
22
|
from collections import deque
|
23
23
|
from concurrent import futures
|
24
24
|
from types import SimpleNamespace
|
25
|
-
from typing import
|
25
|
+
from typing import Dict, List, Optional, Tuple
|
26
26
|
|
27
27
|
import psutil
|
28
28
|
import setproctitle
|
@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
|
|
52
52
|
UpdateWeightFromDiskReqOutput,
|
53
53
|
UpdateWeightsFromDistributedReqInput,
|
54
54
|
UpdateWeightsFromDistributedReqOutput,
|
55
|
+
UpdateWeightsFromTensorReqInput,
|
56
|
+
UpdateWeightsFromTensorReqOutput,
|
55
57
|
)
|
56
58
|
from sglang.srt.managers.schedule_batch import (
|
57
59
|
FINISH_ABORT,
|
@@ -88,7 +90,7 @@ from sglang.utils import get_exception_traceback
|
|
88
90
|
|
89
91
|
logger = logging.getLogger(__name__)
|
90
92
|
|
91
|
-
# Test retract decode
|
93
|
+
# Test retract decode for debugging purposes
|
92
94
|
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
|
93
95
|
|
94
96
|
|
@@ -127,12 +129,12 @@ class Scheduler:
|
|
127
129
|
)
|
128
130
|
|
129
131
|
if server_args.skip_tokenizer_init:
|
130
|
-
# Directly send to the
|
132
|
+
# Directly send to the TokenizerManager
|
131
133
|
self.send_to_detokenizer = get_zmq_socket(
|
132
134
|
context, zmq.PUSH, port_args.tokenizer_ipc_name
|
133
135
|
)
|
134
136
|
else:
|
135
|
-
# Send to the
|
137
|
+
# Send to the DetokenizerManager
|
136
138
|
self.send_to_detokenizer = get_zmq_socket(
|
137
139
|
context, zmq.PUSH, port_args.detokenizer_ipc_name
|
138
140
|
)
|
@@ -383,7 +385,8 @@ class Scheduler:
|
|
383
385
|
self.process_input_requests(recv_reqs)
|
384
386
|
|
385
387
|
batch = self.get_next_batch_to_run()
|
386
|
-
|
388
|
+
|
389
|
+
if self.server_args.enable_dp_attention: # TODO: simplify this
|
387
390
|
batch = self.prepare_dp_attn_batch(batch)
|
388
391
|
|
389
392
|
self.cur_batch = batch
|
@@ -392,7 +395,7 @@ class Scheduler:
|
|
392
395
|
result = self.run_batch(batch)
|
393
396
|
self.process_batch_result(batch, result)
|
394
397
|
else:
|
395
|
-
#
|
398
|
+
# When the server is idle, so self-check and re-init some states
|
396
399
|
self.check_memory()
|
397
400
|
self.new_token_ratio = self.init_new_token_ratio
|
398
401
|
|
@@ -409,12 +412,13 @@ class Scheduler:
|
|
409
412
|
|
410
413
|
batch = self.get_next_batch_to_run()
|
411
414
|
self.cur_batch = batch
|
415
|
+
|
412
416
|
if batch:
|
413
417
|
result = self.run_batch(batch)
|
414
418
|
result_queue.append((batch.copy(), result))
|
415
419
|
|
416
420
|
if self.last_batch is None:
|
417
|
-
#
|
421
|
+
# Create a dummy first batch to start the pipeline for overlap scheduler.
|
418
422
|
# It is now used for triggering the sampling_info_done event.
|
419
423
|
tmp_batch = ScheduleBatch(
|
420
424
|
reqs=None,
|
@@ -424,19 +428,21 @@ class Scheduler:
|
|
424
428
|
self.process_batch_result(tmp_batch, None)
|
425
429
|
|
426
430
|
if self.last_batch:
|
431
|
+
# Process the results of the last batch
|
427
432
|
tmp_batch, tmp_result = result_queue.popleft()
|
428
433
|
tmp_batch.next_batch_sampling_info = (
|
429
434
|
self.tp_worker.cur_sampling_info if batch else None
|
430
435
|
)
|
431
436
|
self.process_batch_result(tmp_batch, tmp_result)
|
432
437
|
elif batch is None:
|
433
|
-
#
|
438
|
+
# When the server is idle, so self-check and re-init some states
|
434
439
|
self.check_memory()
|
435
440
|
self.new_token_ratio = self.init_new_token_ratio
|
436
441
|
|
437
442
|
self.last_batch = batch
|
438
443
|
|
439
|
-
def recv_requests(self):
|
444
|
+
def recv_requests(self) -> List[Req]:
|
445
|
+
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
440
446
|
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
|
441
447
|
recv_reqs = []
|
442
448
|
|
@@ -468,9 +474,6 @@ class Scheduler:
|
|
468
474
|
self.send_to_tokenizer.send_pyobj(
|
469
475
|
UpdateWeightFromDiskReqOutput(success, message)
|
470
476
|
)
|
471
|
-
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
472
|
-
parameter = self.get_weights_by_name(recv_req)
|
473
|
-
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
474
477
|
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
|
475
478
|
success, message = self.init_weights_update_group(recv_req)
|
476
479
|
self.send_to_tokenizer.send_pyobj(
|
@@ -481,6 +484,11 @@ class Scheduler:
|
|
481
484
|
self.send_to_tokenizer.send_pyobj(
|
482
485
|
UpdateWeightsFromDistributedReqOutput(success, message)
|
483
486
|
)
|
487
|
+
elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
|
488
|
+
success, message = self.update_weights_from_tensor(recv_req)
|
489
|
+
self.send_to_tokenizer.send_pyobj(
|
490
|
+
UpdateWeightsFromTensorReqOutput(success, message)
|
491
|
+
)
|
484
492
|
elif isinstance(recv_req, GetWeightsByNameReqInput):
|
485
493
|
parameter = self.get_weights_by_name(recv_req)
|
486
494
|
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
|
@@ -490,8 +498,10 @@ class Scheduler:
|
|
490
498
|
else:
|
491
499
|
self.stop_profile()
|
492
500
|
elif isinstance(recv_req, OpenSessionReqInput):
|
493
|
-
session_id = self.open_session(recv_req)
|
494
|
-
self.send_to_tokenizer.send_pyobj(
|
501
|
+
session_id, success = self.open_session(recv_req)
|
502
|
+
self.send_to_tokenizer.send_pyobj(
|
503
|
+
OpenSessionReqOutput(session_id=session_id, success=success)
|
504
|
+
)
|
495
505
|
elif isinstance(recv_req, CloseSessionReqInput):
|
496
506
|
self.close_session(recv_req)
|
497
507
|
else:
|
@@ -502,7 +512,11 @@ class Scheduler:
|
|
502
512
|
recv_req: TokenizedGenerateReqInput,
|
503
513
|
):
|
504
514
|
# Create a new request
|
505
|
-
if
|
515
|
+
if (
|
516
|
+
recv_req.session_params is None
|
517
|
+
or recv_req.session_params.id is None
|
518
|
+
or recv_req.session_params.id not in self.sessions
|
519
|
+
):
|
506
520
|
|
507
521
|
if recv_req.input_embeds is not None:
|
508
522
|
# Generate fake input_ids based on the length of input_embeds
|
@@ -520,18 +534,22 @@ class Scheduler:
|
|
520
534
|
stream=recv_req.stream,
|
521
535
|
lora_path=recv_req.lora_path,
|
522
536
|
input_embeds=recv_req.input_embeds,
|
537
|
+
eos_token_ids=self.model_config.hf_eos_token_id,
|
523
538
|
)
|
524
539
|
req.tokenizer = self.tokenizer
|
525
540
|
|
526
|
-
if
|
541
|
+
if (
|
542
|
+
recv_req.session_params is not None
|
543
|
+
and recv_req.session_params.id is not None
|
544
|
+
):
|
527
545
|
req.finished_reason = FINISH_ABORT(
|
528
|
-
f"Invalid request: session id {recv_req.
|
546
|
+
f"Invalid request: session id {recv_req.session_params.id} does not exist"
|
529
547
|
)
|
530
548
|
self.waiting_queue.append(req)
|
531
549
|
return
|
532
550
|
else:
|
533
|
-
# Create a new request from a
|
534
|
-
session = self.sessions[recv_req.
|
551
|
+
# Create a new request from a previous session
|
552
|
+
session = self.sessions[recv_req.session_params.id]
|
535
553
|
req = session.create_req(recv_req, self.tokenizer)
|
536
554
|
if isinstance(req.finished_reason, FINISH_ABORT):
|
537
555
|
self.waiting_queue.append(req)
|
@@ -565,7 +583,7 @@ class Scheduler:
|
|
565
583
|
|
566
584
|
if req.logprob_start_len == -1:
|
567
585
|
# By default, only return the logprobs for output tokens
|
568
|
-
req.logprob_start_len = len(
|
586
|
+
req.logprob_start_len = len(req.origin_input_ids) - 1
|
569
587
|
|
570
588
|
# Truncate prompts that are too long
|
571
589
|
if len(req.origin_input_ids) > self.max_req_input_len:
|
@@ -589,12 +607,15 @@ class Scheduler:
|
|
589
607
|
if (
|
590
608
|
req.sampling_params.json_schema is not None
|
591
609
|
or req.sampling_params.regex is not None
|
610
|
+
or req.sampling_params.ebnf is not None
|
592
611
|
):
|
593
612
|
assert self.grammar_backend is not None
|
594
613
|
if req.sampling_params.json_schema is not None:
|
595
614
|
key = ("json", req.sampling_params.json_schema)
|
596
615
|
elif req.sampling_params.regex is not None:
|
597
616
|
key = ("regex", req.sampling_params.regex)
|
617
|
+
elif req.sampling_params.ebnf is not None:
|
618
|
+
key = ("ebnf", req.sampling_params.ebnf)
|
598
619
|
|
599
620
|
req.grammar = self.grammar_backend.get_cached_value(key)
|
600
621
|
if not req.grammar:
|
@@ -629,16 +650,13 @@ class Scheduler:
|
|
629
650
|
self.waiting_queue.append(req)
|
630
651
|
|
631
652
|
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
)
|
640
|
-
else:
|
641
|
-
tree_cache_hit_rate = 0.0
|
653
|
+
self.tree_cache_metrics["total"] += (
|
654
|
+
adder.log_input_tokens + adder.log_hit_tokens
|
655
|
+
) / 10**9
|
656
|
+
self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
|
657
|
+
tree_cache_hit_rate = (
|
658
|
+
self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
|
659
|
+
)
|
642
660
|
|
643
661
|
num_used = self.max_total_num_tokens - (
|
644
662
|
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
@@ -807,6 +825,8 @@ class Scheduler:
|
|
807
825
|
if res == AddReqResult.NO_TOKEN:
|
808
826
|
self.batch_is_full = True
|
809
827
|
break
|
828
|
+
if self.server_args.prefill_only_one_req:
|
829
|
+
break
|
810
830
|
|
811
831
|
# Update waiting queue
|
812
832
|
can_run_list = adder.can_run_list
|
@@ -1460,6 +1480,17 @@ class Scheduler:
|
|
1460
1480
|
logger.error(message)
|
1461
1481
|
return success, message
|
1462
1482
|
|
1483
|
+
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
1484
|
+
"""Update the online model parameter from tensors."""
|
1485
|
+
success, message = self.tp_worker.update_weights_from_tensor(recv_req)
|
1486
|
+
# TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later
|
1487
|
+
if success:
|
1488
|
+
flash_cache_success = self.flush_cache()
|
1489
|
+
assert flash_cache_success, "Cache flush failed after updating weights"
|
1490
|
+
else:
|
1491
|
+
logger.error(message)
|
1492
|
+
return success, message
|
1493
|
+
|
1463
1494
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
1464
1495
|
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
1465
1496
|
return parameter
|
@@ -1478,16 +1509,20 @@ class Scheduler:
|
|
1478
1509
|
)
|
1479
1510
|
logger.info("Profiler is done")
|
1480
1511
|
|
1481
|
-
def open_session(self, recv_req: OpenSessionReqInput) -> str:
|
1512
|
+
def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
|
1482
1513
|
# handle error
|
1483
1514
|
session_id = recv_req.session_id
|
1484
1515
|
if session_id in self.sessions:
|
1485
1516
|
logger.warning(f"session id {session_id} already exist, cannot open.")
|
1517
|
+
return session_id, False
|
1518
|
+
elif session_id is None:
|
1519
|
+
logger.warning(f"session id is None, cannot open.")
|
1520
|
+
return session_id, False
|
1486
1521
|
else:
|
1487
1522
|
self.sessions[session_id] = Session(
|
1488
1523
|
recv_req.capacity_of_str_len, session_id
|
1489
1524
|
)
|
1490
|
-
|
1525
|
+
return session_id, True
|
1491
1526
|
|
1492
1527
|
def close_session(self, recv_req: CloseSessionReqInput):
|
1493
1528
|
# handle error
|
@@ -1512,18 +1547,20 @@ def run_scheduler_process(
|
|
1512
1547
|
if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
|
1513
1548
|
dp_rank = int(os.environ["SGLANG_DP_RANK"])
|
1514
1549
|
|
1550
|
+
# Configue the logger
|
1515
1551
|
if dp_rank is None:
|
1516
1552
|
configure_logger(server_args, prefix=f" TP{tp_rank}")
|
1517
1553
|
else:
|
1518
1554
|
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
|
1555
|
+
suppress_other_loggers()
|
1519
1556
|
|
1520
|
-
#
|
1557
|
+
# Set cpu affinity to this gpu process
|
1521
1558
|
if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
|
1522
1559
|
set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
|
1523
1560
|
|
1524
|
-
suppress_other_loggers()
|
1525
1561
|
parent_process = psutil.Process().parent()
|
1526
1562
|
|
1563
|
+
# Create a scheduler and run the event loop
|
1527
1564
|
try:
|
1528
1565
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
1529
1566
|
pipe_writer.send(
|
@@ -10,41 +10,116 @@
|
|
10
10
|
# limitations under the License.
|
11
11
|
# ==============================================================================
|
12
12
|
|
13
|
+
import logging
|
13
14
|
import uuid
|
15
|
+
from typing import Dict, Optional
|
14
16
|
|
15
17
|
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
16
|
-
from sglang.srt.managers.schedule_batch import
|
18
|
+
from sglang.srt.managers.schedule_batch import Req
|
19
|
+
|
20
|
+
|
21
|
+
class SessionReqNode:
|
22
|
+
def __init__(self, req, parent=None, childs=None):
|
23
|
+
self.req = req
|
24
|
+
self.parent = parent
|
25
|
+
if parent is not None:
|
26
|
+
parent.childs.append(self)
|
27
|
+
self.childs = [] if not childs else childs
|
28
|
+
|
29
|
+
def clear_childs(self, req_dict):
|
30
|
+
for req_node in self.childs:
|
31
|
+
req_node.clear(req_dict)
|
32
|
+
self.childs = []
|
33
|
+
|
34
|
+
def clear(self, req_dict):
|
35
|
+
for req_node in self.childs:
|
36
|
+
req_node.clear(req_dict)
|
37
|
+
|
38
|
+
if self.req.finished_reason == None:
|
39
|
+
self.req.to_abort = True
|
40
|
+
del req_dict[self.req.rid]
|
41
|
+
|
42
|
+
def abort(self):
|
43
|
+
if self.req.finished_reason == None:
|
44
|
+
self.req.to_abort = True
|
45
|
+
|
46
|
+
def __str__(self):
|
47
|
+
return self._str_helper(self.req.rid)
|
48
|
+
|
49
|
+
def _str_helper(self, prefix=""):
|
50
|
+
if len(self.childs) == 0:
|
51
|
+
return prefix + "\n"
|
52
|
+
else:
|
53
|
+
origin_prefix = prefix
|
54
|
+
prefix += " -- " + self.childs[0].req.rid
|
55
|
+
ret = self.childs[0]._str_helper(prefix)
|
56
|
+
for child in self.childs[1:]:
|
57
|
+
prefix = " " * len(origin_prefix) + " \- " + child.req.rid
|
58
|
+
ret += child._str_helper(prefix)
|
59
|
+
return ret
|
17
60
|
|
18
61
|
|
19
62
|
class Session:
|
20
|
-
def __init__(self, capacity_of_str_len: int, session_id: str = None):
|
63
|
+
def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
|
21
64
|
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
|
22
65
|
self.capacity_of_str_len = capacity_of_str_len
|
23
|
-
self.
|
66
|
+
self.req_nodes: Dict[str, SessionReqNode] = {}
|
24
67
|
|
25
68
|
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
69
|
+
assert req.session_params is not None
|
70
|
+
session_params = req.session_params
|
71
|
+
|
72
|
+
last_req_node = None
|
73
|
+
last_req = None
|
74
|
+
abort = False
|
75
|
+
if session_params.replace:
|
76
|
+
if session_params.rid is None:
|
77
|
+
for _, req_node in self.req_nodes.items():
|
78
|
+
req_node.clear(self.req_nodes)
|
79
|
+
else:
|
80
|
+
if session_params.rid not in self.req_nodes:
|
81
|
+
abort = True
|
82
|
+
else:
|
83
|
+
last_req_node = self.req_nodes[session_params.rid]
|
84
|
+
last_req_node.abort()
|
85
|
+
last_req = last_req_node.req
|
86
|
+
last_req_node.clear_childs(self.req_nodes)
|
31
87
|
else:
|
32
|
-
|
33
|
-
|
88
|
+
if session_params.rid is not None:
|
89
|
+
if session_params.rid not in self.req_nodes:
|
90
|
+
abort = True
|
91
|
+
else:
|
92
|
+
last_req_node = self.req_nodes[session_params.rid]
|
93
|
+
last_req = last_req_node.req
|
94
|
+
if not last_req.finished():
|
95
|
+
logging.warning(
|
96
|
+
"The request in a session is appending to a request that hasn't finished."
|
97
|
+
)
|
98
|
+
abort = True
|
99
|
+
|
100
|
+
if last_req is not None:
|
101
|
+
# trim bos token if it is an append
|
102
|
+
if req.input_ids[0] == tokenizer.bos_token_id:
|
103
|
+
req.input_ids = req.input_ids[1:]
|
104
|
+
|
34
105
|
input_ids = (
|
35
|
-
|
36
|
-
+
|
37
|
-
: self.reqs[-1].sampling_params.max_new_tokens
|
38
|
-
]
|
39
|
-
+ req.input_ids
|
106
|
+
last_req.origin_input_ids
|
107
|
+
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
|
40
108
|
)
|
109
|
+
if session_params.offset and session_params.offset != 0:
|
110
|
+
input_ids = input_ids[: session_params.offset] + req.input_ids
|
111
|
+
else:
|
112
|
+
input_ids += req.input_ids
|
41
113
|
input_ids_unpadded = (
|
42
|
-
|
43
|
-
+
|
44
|
-
: self.reqs[-1].sampling_params.max_new_tokens
|
45
|
-
]
|
46
|
-
+ req.input_ids
|
114
|
+
last_req.origin_input_ids_unpadded
|
115
|
+
+ last_req.output_ids[: last_req.sampling_params.max_new_tokens]
|
47
116
|
)
|
117
|
+
if session_params.offset and session_params.offset != 0:
|
118
|
+
input_ids_unpadded = (
|
119
|
+
input_ids_unpadded[: session_params.offset] + req.input_ids
|
120
|
+
)
|
121
|
+
else:
|
122
|
+
input_ids_unpadded += req.input_ids
|
48
123
|
else:
|
49
124
|
input_ids = req.input_ids
|
50
125
|
input_ids_unpadded = req.input_ids
|
@@ -57,13 +132,13 @@ class Session:
|
|
57
132
|
lora_path=req.lora_path,
|
58
133
|
session_id=self.session_id,
|
59
134
|
)
|
60
|
-
if
|
61
|
-
new_req.image_inputs =
|
135
|
+
if last_req is not None:
|
136
|
+
new_req.image_inputs = last_req.image_inputs
|
62
137
|
new_req.tokenizer = tokenizer
|
63
|
-
if
|
64
|
-
new_req.
|
65
|
-
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
|
66
|
-
)
|
138
|
+
if abort:
|
139
|
+
new_req.to_abort = True
|
67
140
|
else:
|
68
|
-
|
141
|
+
new_req_node = SessionReqNode(new_req, last_req_node)
|
142
|
+
self.req_nodes[req.rid] = new_req_node
|
143
|
+
|
69
144
|
return new_req
|