sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
```diff
@@ -179,6 +179,27 @@ class EmbeddingBatchResult:
     bid: int
 
 
+class IdleSleeper:
+    """
+    In setups which have long inactivity periods it is desirable to reduce
+    system power consumption when sglang does nothing. This would lead not only
+    to power savings, but also to more CPU thermal headroom when a request
+    eventually comes. This is important in cases when multiple GPUs are connected
+    as each GPU would otherwise pin one thread at 100% CPU usage.
+
+    The simplest solution is to use zmq.Poller on all sockets that may receive
+    data that needs handling immediately.
+    """
+
+    def __init__(self, sockets):
+        self.poller = zmq.Poller()
+        for s in sockets:
+            self.poller.register(s, zmq.POLLIN)
+
+    def maybe_sleep(self):
+        self.poller.poll(1000)
+
+
 class Scheduler(
     SchedulerOutputProcessorMixin,
     SchedulerDisaggregationDecodeMixin,
```
```diff
@@ -228,6 +249,8 @@ class Scheduler(
 
         # Init inter-process communication
         context = zmq.Context(2)
+        self.idle_sleeper = None
+
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -250,6 +273,13 @@ class Scheduler(
             self.recv_from_rpc = get_zmq_socket(
                 context, zmq.DEALER, port_args.rpc_ipc_name, False
             )
+            if self.server_args.sleep_on_idle:
+                self.idle_sleeper = IdleSleeper(
+                    [
+                        self.recv_from_tokenizer,
+                        self.recv_from_rpc,
+                    ]
+                )
         else:
             self.recv_from_tokenizer = None
             self.recv_from_rpc = None
```
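The sleeper is opt-in and lives only on the rank that owns the tokenizer and RPC sockets, gated by `self.server_args.sleep_on_idle`. A hedged sketch of turning it on through the offline engine, assuming `sglang.Engine` forwards keyword arguments into `ServerArgs` (the model path is an arbitrary example):

```python
import sglang as sgl

# Assumption: Engine(**kwargs) forwards unrecognized kwargs to ServerArgs,
# so sleep_on_idle=True reaches the Scheduler shown in the diff above.
llm = sgl.Engine(
    model_path="Qwen/Qwen2.5-0.5B-Instruct",  # example model
    sleep_on_idle=True,
)
print(llm.generate("Hello, my name is", {"max_new_tokens": 8}))
```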
```diff
@@ -361,7 +391,7 @@ class Scheduler(
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.
+        self.last_prefill_tokens = 0
         self.last_decode_stats_tic = time.perf_counter()
         self.last_prefill_stats_tic = time.perf_counter()
         self.return_health_check_ct = 0
@@ -478,6 +508,10 @@ class Scheduler(
         )
         self.init_disaggregation()
 
+    def maybe_sleep_on_idle(self):
+        if self.idle_sleeper is not None:
+            self.idle_sleeper.maybe_sleep()
+
     def init_tokenizer(self):
         server_args = self.server_args
 
@@ -585,7 +619,7 @@
             self.disaggregation_mode == DisaggregationMode.DECODE
         ):  # *2 for the headroom.
             buffer_size = (self.req_to_token_pool.size) * 2
-            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
             self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
@@ -593,7 +627,8 @@
             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
-                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
+                tp_rank=self.tp_rank,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
@@ -608,7 +643,7 @@
                     if self.draft_worker is None
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
-                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
@@ -616,7 +651,12 @@
                 gloo_group=self.attn_tp_cpu_group,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+                gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
+                max_total_num_tokens=self.max_total_num_tokens,
+                prefill_pp_size=self.server_args.disaggregation_prefill_pp,
+                num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
                 transfer_backend=self.transfer_backend,
             )
 
@@ -626,7 +666,7 @@
         elif self.disaggregation_mode == DisaggregationMode.PREFILL:
             # *2 for the headroom.
             buffer_size = self.max_running_requests * 2
-            req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
+            self.req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
             self.disagg_metadata_buffers = MetadataBuffers(buffer_size)
@@ -638,14 +678,20 @@
                     if self.draft_worker is None
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
-                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
+                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                 metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
+                gpu_id=self.gpu_id,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                 gloo_group=self.attn_tp_cpu_group,
-
+                max_total_num_tokens=self.max_total_num_tokens,
+                decode_tp_size=self.server_args.disaggregation_decode_tp,
+                decode_dp_size=self.server_args.disaggregation_decode_dp,
                 scheduler=self,
+                pp_rank=self.pp_rank,
+                pp_size=self.pp_size,
+                transfer_backend=self.transfer_backend,
             )
             # The prefill requests that are in the middle of kv sending
             self.disagg_prefill_inflight_queue: List[Req] = []
@@ -667,6 +713,7 @@
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()
 
             self.last_batch = batch
 
@@ -711,6 +758,7 @@
                 # When the server is idle, do self-check and re-init some states
                 self.check_memory()
                 self.new_token_ratio = self.init_new_token_ratio
+                self.maybe_sleep_on_idle()
 
             self.last_batch = batch
 
@@ -816,6 +864,7 @@
         if server_is_idle:
             self.check_memory()
             self.new_token_ratio = self.init_new_token_ratio
+            self.maybe_sleep_on_idle()
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
@@ -1073,18 +1122,22 @@
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.add(
+            self.disagg_prefill_bootstrap_queue.add(
+                req, self.model_config.num_key_value_heads
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.disagg_decode_prealloc_queue.add(req)
         else:
             self.waiting_queue.append(req)
 
-    def _extend_requests_to_queue(self, reqs: List[Req]):
+    def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.disagg_prefill_bootstrap_queue.extend(
+            self.disagg_prefill_bootstrap_queue.extend(
+                reqs, self.model_config.num_key_value_heads
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             # If this is a decode server, we put the request to the decode pending prealloc queue
-            self.disagg_decode_prealloc_queue.extend(reqs)
+            self.disagg_decode_prealloc_queue.extend(reqs, is_retracted)
         else:
             self.waiting_queue.extend(reqs)
 
@@ -1097,6 +1150,7 @@
             recv_req.input_text,
             recv_req.input_ids,
             recv_req.sampling_params,
+            token_type_ids=recv_req.token_type_ids,
         )
         req.tokenizer = self.tokenizer
 
@@ -1141,8 +1195,8 @@
         ):
             gap_latency = time.perf_counter() - self.last_prefill_stats_tic
             self.last_prefill_stats_tic = time.perf_counter()
-            self.last_input_throughput = self.
-            self.
+            self.last_input_throughput = self.last_prefill_tokens / gap_latency
+            self.last_prefill_tokens = adder.log_input_tokens
 
             num_used = self.max_total_num_tokens - (
                 self.token_to_kv_pool_allocator.available_size()
```
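The fixed bookkeeping divides the tokens admitted during the previous logging interval by that interval's length, then stores `adder.log_input_tokens` for the next round. With hypothetical numbers:

```python
# Hypothetical numbers illustrating the formula from the hunk above.
last_prefill_tokens = 2048  # tokens admitted since the previous prefill log
gap_latency = 0.5           # seconds since last_prefill_stats_tic
last_input_throughput = last_prefill_tokens / gap_latency
print(f"input throughput (token/s): {last_input_throughput:.2f}")  # 4096.00
```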
```diff
@@ -1156,15 +1210,15 @@
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-            f"#running-req: {running_bs}, "
         )
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
             f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
             f += f"#queue-req: {len(self.waiting_queue)}, "
             f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)}, "
-            f += f"
+            f += f"input throughput (token/s): {self.last_input_throughput:.2f} "
         else:
+            f += f"#running-req: {running_bs}, "
             f += f"#queue-req: {len(self.waiting_queue)}"
 
         logger.info(f)
@@ -1227,6 +1281,7 @@
 
         if self.disaggregation_mode == DisaggregationMode.DECODE:
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+            msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "
 
         msg += (
             f"cuda graph: {can_run_cuda_graph}, "
@@ -1528,7 +1583,7 @@
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
-            self._extend_requests_to_queue(retracted_reqs)
+            self._extend_requests_to_queue(retracted_reqs, is_retracted=True)
         else:
             self.new_token_ratio = max(
                 self.new_token_ratio - self.new_token_ratio_decay,
@@ -2055,7 +2110,8 @@
             # In this case, we change the input_ids to be only one token to make this prefill cheap.
             if req.rid.startswith(recv_req.rid):
                 logger.debug(f"Abort grammar queue request. {req.rid=}")
-                req.grammar
+                if req.grammar:
+                    req.grammar.cancel()
                 req.set_finish_with_abort("Aborted by AbortReq.")
 
         # Delete requests in the running batch
```
sglang/srt/managers/tokenizer_manager.py
CHANGED
```diff
@@ -418,6 +418,20 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
+        if isinstance(obj, GenerateReqInput):
+            return_hidden_states = obj.return_hidden_states
+            has_return_hidden_states = return_hidden_states == True or (
+                isinstance(return_hidden_states, list) and any(return_hidden_states)
+            )
+            if (
+                not self.server_args.enable_return_hidden_states
+                and has_return_hidden_states
+            ):
+                raise ValueError(
+                    "return_hidden_states=True requires the server to be started "
+                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
```
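The guard turns a previously silent mismatch into an explicit error. A hedged sketch of tripping it over HTTP, assuming a local server started without `--enable-return-hidden-states` (the address and payload are illustrative; `/generate` and `return_hidden_states` are existing sglang API surface):

```python
import requests

# Assumes a local sglang server launched WITHOUT --enable-return-hidden-states.
resp = requests.post(
    "http://localhost:30000/generate",  # default local port; adjust as needed
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8},
        "return_hidden_states": True,  # now rejected unless the server flag is set
    },
)
print(resp.status_code, resp.text)  # expect an error naming the flag
```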
```diff
@@ -445,6 +459,10 @@ class TokenizerManager:
         # Tokenize
         input_embeds = None
         input_text = obj.text
+        token_type_ids = None
+        is_cross_encoder_request = (
+            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
+        )
         if obj.input_embeds is not None:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
@@ -463,7 +481,14 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-
+            encoded = self.tokenizer(
+                input_text, return_token_type_ids=is_cross_encoder_request
+            )
+
+            input_ids = encoded["input_ids"]
+            if is_cross_encoder_request:
+                input_ids = encoded["input_ids"][0]
+                token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(
```
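For cross-encoder (sentence-pair) requests the tokenizer is additionally asked for `token_type_ids`, which BERT-style rerankers use to tell the two segments apart. A small illustration with a stock HuggingFace tokenizer (the model name is only an example):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model
enc = tok(
    "how many people live in berlin?",            # query  -> segment 0
    "berlin has about 3.7 million inhabitants.",  # passage -> segment 1
    return_token_type_ids=True,
)
print(enc["input_ids"])       # [CLS] query [SEP] passage [SEP]
print(enc["token_type_ids"])  # 0s over the query span, 1s over the passage
```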
```diff
@@ -479,7 +504,7 @@ class TokenizerManager:
 
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs
+            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
 
     def _validate_token_len(
@@ -518,6 +543,7 @@ class TokenizerManager:
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
         image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
 
@@ -578,6 +604,7 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 image_inputs,
+                token_type_ids,
                 sampling_params,
             )
 
```
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
```diff
@@ -9,12 +9,14 @@ import torch
 from sglang.srt.managers.cache_controller import HiCacheController
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
-    MHATokenToKVPoolHost,
     MLATokenToKVPool,
-    MLATokenToKVPoolHost,
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
+)
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
 logger = logging.getLogger(__name__)
```