sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the releases exactly as they appear in that registry.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +1 -0
- sglang/bench_serving.py +9 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +51 -5
- sglang/srt/layers/attention/triton_backend.py +16 -25
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
- sglang/srt/layers/moe/topk.py +191 -0
- sglang/srt/layers/quantization/__init__.py +5 -50
- sglang/srt/layers/quantization/fp8.py +221 -36
- sglang/srt/layers/quantization/fp8_kernel.py +278 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/radix_attention.py +8 -1
- sglang/srt/layers/sampler.py +27 -5
- sglang/srt/layers/torchao_utils.py +31 -0
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +54 -34
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +171 -136
- sglang/srt/managers/tokenizer_manager.py +184 -133
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -11
- sglang/srt/model_executor/model_runner.py +28 -14
- sglang/srt/model_parallel.py +66 -5
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +73 -9
- sglang/srt/models/llama.py +22 -0
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +8 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/server.py +2 -1
- sglang/srt/server_args.py +19 -9
- sglang/srt/utils.py +40 -54
- sglang/test/test_block_fp8.py +341 -0
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py:

```diff
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import uvloop
@@ -30,6 +30,7 @@ import zmq
 import zmq.asyncio
 from fastapi import BackgroundTasks
 
+from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.image_processor import (
```
```diff
@@ -62,7 +63,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    dataclass_to_string_truncated,
+    get_zmq_socket,
+    kill_process_tree,
+)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -76,11 +81,15 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any
 
     # For metrics
     created_time: float
     first_token_time: Optional[float] = None
 
+    # For streaming output
+    last_output_offset: int = 0
+
 
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
```
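Putting that hunk together, the per-request state record now carries the originating request object and a streaming offset. A sketch of the resulting class, assuming it is a dataclass (the bare-annotation field syntax suggests so; the decorator sits outside the excerpt):

```python
import asyncio
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ReqState:
    out_list: List
    finished: bool
    event: asyncio.Event
    obj: Any  # new in 0.4.1: the original request object

    # For metrics
    created_time: float
    first_token_time: Optional[float] = None

    # For streaming output
    last_output_offset: int = 0  # new in 0.4.1
```

Storing `obj` matters later in this diff: `handle_loop` reads `state.obj.return_logprob` and related options when assembling `meta_info`, and `_wait_one_response` now passes `obj` positionally when constructing `ReqState`.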
```diff
@@ -119,6 +128,7 @@ class TokenizerManager:
 
         self.is_generation = self.model_config.is_generation
         self.context_len = self.model_config.context_len
+        self.image_token_id = self.model_config.image_token_id
 
         # Create image processor placeholder
         self.image_processor = get_dummy_image_processor()
@@ -151,9 +161,12 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
-        #
-        self.model_update_lock =
-        self.model_update_result =
+        # The event to notify the weight sync is finished.
+        self.model_update_lock = RWLock()
+        self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
+            None
+        )
+        self.asyncio_tasks = set()
 
         # For session info
         self.session_futures = {}  # session_id -> asyncio event
@@ -180,9 +193,6 @@ class TokenizerManager:
         if self.to_create_loop:
             self.create_handle_loop()
 
-        while self.model_update_lock.locked():
-            await asyncio.sleep(0.001)
-
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
                 "This model does not appear to be an embedding model by default. "
@@ -190,17 +200,24 @@ class TokenizerManager:
             )
 
         obj.normalize_batch_and_arguments()
-
-        if
-
-
-
-
-
-
-
-
-
+
+        if self.server_args.log_requests:
+            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+
+        async with self.model_update_lock.reader_lock:
+            is_single = obj.is_single
+            if is_single:
+                tokenized_obj = await self._tokenize_one_request(obj)
+                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                async for response in self._wait_one_response(
+                    obj, request, created_time
+                ):
+                    yield response
+            else:
+                async for response in self._handle_batch_request(
+                    obj, request, created_time
+                ):
+                    yield response
 
     async def _tokenize_one_request(
         self,
```
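Two things change in `generate_request`: the busy-wait on `model_update_lock.locked()` is replaced by holding `reader_lock` for the request's lifetime, and request logging goes through the newly imported `dataclass_to_string_truncated`. A plausible shape for that helper, as a sketch only; the real one lives in sglang/srt/utils.py, and its signature, length limit, and formatting may differ:

```python
import dataclasses


def dataclass_to_string_truncated(obj, max_length: int = 2048) -> str:
    """Format a dataclass for logs, truncating oversized field values."""

    def fmt(value) -> str:
        text = repr(value)
        if len(text) > max_length:
            text = text[:max_length] + " ...(truncated)"
        return text

    parts = (
        f"{f.name}={fmt(getattr(obj, f.name))}" for f in dataclasses.fields(obj)
    )
    return f"{type(obj).__name__}({', '.join(parts)})"
```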
```diff
@@ -214,7 +231,7 @@ class TokenizerManager:
         if not self.server_args.disable_radix_cache:
             raise ValueError(
                 "input_embeds is provided while disable_radix_cache is False. "
-                "Please add `--disable-radix-
+                "Please add `--disable-radix-cache` when you launch the server "
                 "if you want to use input_embeds as inputs."
             )
         input_embeds = obj.input_embeds
@@ -283,7 +300,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
 
         while True:
@@ -295,27 +312,25 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue
 
-
-            out = self.convert_logprob_style(
-                state.out_list[-1],
-                obj.return_logprob,
-                obj.top_logprobs_num,
-                obj.return_text_in_logprobs,
-            )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]
 
             state.out_list = []
             if state.finished:
                 if self.server_args.log_requests:
-
-                    logger.info(
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                    logger.info(msg)
                 del self.rid_to_state[obj.rid]
                 yield out
                 break
 
             state.event.clear()
-
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
 
     async def _handle_batch_request(
         self,
@@ -424,55 +439,52 @@ class TokenizerManager:
         self,
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
 
         # default the load format to the server_args
         if obj.load_format is None:
             obj.load_format = self.server_args.load_format
+        logger.info("Start update_weights. Load format=%s", obj.load_format)
 
-        if
+        if True:
+            # Hold the lock if it is not async. This means that weight sync
+            # cannot run while requests are in progress.
+            async with self.model_update_lock.writer_lock:
+                return await self._wait_for_model_update_from_disk(obj)
 
-
-
-
-
-
-
-
-            self.
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            self.server_args.load_format = obj.load_format
-            self.model_path = obj.model_path
-            all_message = [r.message for r in result]
-            all_message = " | ".join(all_message)
-            return all_success, all_message
-
-        else:
-            return False, "Another update is in progress. Please try again later."
+    async def _wait_for_model_update_from_disk(
+        self, obj: UpdateWeightFromDiskReqInput
+    ) -> Tuple[bool, str, int]:
+        self.send_to_scheduler.send_pyobj(obj)
+        self.model_update_result = asyncio.Future()
+        if self.server_args.dp_size == 1:
+            result = await self.model_update_result
+            if result.success:
+                self.served_model_name = obj.model_path
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            return result.success, result.message
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp = []
+            result = await self.model_update_result
+
+            all_success = all([r.success for r in result])
+            if all_success is True:
+                self.server_args.model_path = obj.model_path
+                self.server_args.load_format = obj.load_format
+                self.model_path = obj.model_path
+            all_message = [r.message for r in result]
+            all_message = " | ".join(all_message)
+            return all_success, all_message
 
     async def init_weights_update_group(
         self,
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
-    ) -> bool:
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
         self.send_to_scheduler.send_pyobj(obj)
```
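`update_weights_from_disk` now funnels through `_wait_for_model_update_from_disk`, which uses a common asyncio pattern: send a message to the scheduler over ZMQ, park on an `asyncio.Future`, and let the background `handle_loop` resolve it when the reply arrives. Stripped to its essentials (a sketch with invented names, not sglang's classes):

```python
import asyncio


class FutureRpcSketch:
    """Send a request out-of-band, then await the reply via a Future."""

    def __init__(self, send):
        self._send = send  # e.g. a zmq socket's send_pyobj
        self._pending = None

    async def call(self, msg):
        self._send(msg)
        self._pending = asyncio.get_running_loop().create_future()
        # The coroutine parks here until on_reply() resolves the future.
        return await self._pending

    def on_reply(self, reply):
        # Called from the background receive loop (handle_loop in sglang).
        if self._pending is not None and not self._pending.done():
            self._pending.set_result(reply)
```

Holding `writer_lock` around the whole exchange is what keeps requests (readers) from observing half-updated weights. Note the `Tuple[bool, str, int]` annotation on `_wait_for_model_update_from_disk` even though both branches return two values; that mismatch is in the released code as shown.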
```diff
@@ -488,25 +500,22 @@ class TokenizerManager:
         self,
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
-    ):
+    ) -> Tuple[bool, str]:
         if self.to_create_loop:
             self.create_handle_loop()
 
-
-
-
-
-
-
-
-
-
-
-
-            return
-                False,
-                "Another parameter update is in progress. Please try again later.",
-            )
+        # This means that weight sync
+        # cannot run while requests are in progress.
+        async with self.model_update_lock.writer_lock:
+            self.send_to_scheduler.send_pyobj(obj)
+            self.parameter_update_result: Awaitable[
+                UpdateWeightsFromDistributedReqOutput
+            ] = asyncio.Future()
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be for update weights from distributed"
+            result = await self.parameter_update_result
+            return result.success, result.message
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -565,15 +574,15 @@ class TokenizerManager:
 
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        loop.create_task(self.handle_loop())
+        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
 
         signal_handler = SignalHandler(self)
         loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        loop.create_task(self.sigterm_watchdog())
+        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
-            await asyncio.sleep(
+            await asyncio.sleep(5)
 
         # drain requests
         while True:
```
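The switch from bare `loop.create_task(...)` to `self.asyncio_tasks.add(loop.create_task(...))` addresses a standard asyncio pitfall: the event loop keeps only a weak reference to a task, so a fire-and-forget task can be garbage-collected before it completes. The usual pattern, per the asyncio docs (the `discard` callback is the fuller version, which this diff does not show):

```python
import asyncio

background_tasks: set = set()


def spawn(coro) -> "asyncio.Task":
    """Create a task and keep a strong reference until it finishes.

    Must be called while the event loop is running.
    """
    task = asyncio.create_task(coro)
    background_tasks.add(task)
    # Optional refinement (not in the diff): drop the reference once the
    # task completes so the set does not grow without bound.
    task.add_done_callback(background_tasks.discard)
    return task
```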
```diff
@@ -609,29 +618,55 @@ class TokenizerManager:
             if state is None:
                 continue
 
-
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
+                )
+
             if isinstance(recv_obj, BatchStrOut):
                 out_dict = {
                     "text": recv_obj.output_strs[i],
-                    "meta_info":
+                    "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
                 out_dict = {
                     "token_ids": recv_obj.output_ids[i],
-                    "meta_info":
+                    "meta_info": meta_info,
                 }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
                     "embedding": recv_obj.embeddings[i],
-                    "meta_info":
+                    "meta_info": meta_info,
                 }
             state.out_list.append(out_dict)
-            state.finished = recv_obj.
+            state.finished = recv_obj.finished_reasons[i] is not None
             state.event.set()
 
             if self.enable_metrics:
-                completion_tokens =
+                completion_tokens = (
+                    recv_obj.completion_tokens[i]
+                    if recv_obj.completion_tokens
+                    else 0
+                )
 
                 if state.first_token_time is None:
                     state.first_token_time = time.time()
```
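This hunk reflects a protocol change visible throughout 0.4.1: batch outputs from the detokenizer now arrive as parallel per-field lists (`finished_reasons[i]`, `prompt_tokens[i]`, ...) rather than per-request metadata dicts, and the tokenizer manager reassembles each request's `meta_info` from column `i`. In miniature (field names taken from the diff; the real `BatchStrOut` lives in io_struct.py and has more fields):

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class BatchOutSketch:
    rids: List[str] = field(default_factory=list)
    finished_reasons: List[Any] = field(default_factory=list)  # None = running
    prompt_tokens: List[int] = field(default_factory=list)
    completion_tokens: List[int] = field(default_factory=list)
    cached_tokens: List[int] = field(default_factory=list)


def meta_info_for(batch: BatchOutSketch, i: int) -> dict:
    """Rebuild one request's metadata from the i-th column of the batch."""
    return {
        "id": batch.rids[i],
        "finish_reason": batch.finished_reasons[i],
        "prompt_tokens": batch.prompt_tokens[i],
        "completion_tokens": batch.completion_tokens[i],
        "cached_tokens": batch.cached_tokens[i],
    }
```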
```diff
@@ -647,7 +682,7 @@ class TokenizerManager:
 
                 if state.finished:
                     self.metrics_collector.inc_prompt_tokens(
-                        recv_obj.
+                        recv_obj.prompt_tokens[i]
                     )
                     self.metrics_collector.inc_generation_tokens(
                         completion_tokens
@@ -696,57 +731,73 @@ class TokenizerManager:
 
     def convert_logprob_style(
         self,
-
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-
-
-
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
-
-
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
             )
 
-        if top_logprobs_num > 0:
-            ret["meta_info"]["input_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["input_top_logprobs"],
-                    return_text_in_logprobs,
-                )
-            )
-            ret["meta_info"]["output_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                )
-            )
-        return ret
-
     def detokenize_logprob_tokens(
-        self,
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
     ):
-        # TODO(lianmin): This should run on DetokenizerManager
         if not decode_to_text:
-            return [
-
-
-
-
-
-
-            ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
 
-    def detokenize_top_logprobs_tokens(
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
         # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
         # We should batch all top-k tokens in all positions.
-
-
-
-
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
+                    )
                 )
-
+            else:
+                ret.append(None)
+        return ret
 
 
 class SignalHandler:
```
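The rewritten `detokenize_logprob_tokens` makes the logprob wire format concrete: values and token ids travel as parallel lists and come back as `(logprob, token_id, text_or_None)` triples. A toy run with made-up numbers and decoded strings:

```python
token_logprobs_val = [-0.11, -2.30]
token_logprobs_idx = [791, 1917]

# decode_to_text=False: keep ids, no text.
triples = [
    (logprob, token_id, None)
    for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
assert triples == [(-0.11, 791, None), (-2.30, 1917, None)]

# decode_to_text=True would call tokenizer.batch_decode on the id list and
# zip the decoded strings in as the third element, e.g.
# [(-0.11, 791, "The"), (-2.30, 1917, " world")].
```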
sglang/srt/mem_cache/base_prefix_cache.py:

```diff
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass
 
     @abstractmethod
```

sglang/srt/mem_cache/chunk_cache.py:

```diff
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None
 
```
sglang/srt/mem_cache/memory_pool.py:

```diff
@@ -184,26 +184,35 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         device: str,
     ):
         super().__init__(size, dtype, device)
+        self.head_num = head_num
+        self.head_dim = head_dim
+        self.layer_num = layer_num
+        self._create_buffers()
 
+    def _create_buffers(self):
         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim),
+                (self.size + 1, self.head_num, self.head_dim),
                 dtype=self.store_dtype,
-                device=device,
+                device=self.device,
             )
-            for _ in range(layer_num)
+            for _ in range(self.layer_num)
         ]
 
+    def _clear_buffers(self):
+        del self.k_buffer
+        del self.v_buffer
+
     def get_key_buffer(self, layer_id: int):
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id].view(self.dtype)
```
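Stashing `head_num`/`head_dim`/`layer_num` on `self` and splitting allocation into `_create_buffers`/`_clear_buffers` lets the pool drop its KV tensors and rebuild them later, which would pair naturally with the weight-update paths elsewhere in this release (an assumption about motivation, not stated in the diff). The shape contract, in a self-contained sketch following the diff's layout:

```python
import torch


class KVPoolSketch:
    def __init__(self, size, head_num, head_dim, layer_num, dtype, device):
        self.size, self.head_num, self.head_dim = size, head_num, head_dim
        self.layer_num, self.dtype, self.device = layer_num, dtype, device
        self._create_buffers()

    def _create_buffers(self):
        # One (size + 1, head_num, head_dim) K and V tensor per layer;
        # index 0 is the padded slot for dummy writes.
        shape = (self.size + 1, self.head_num, self.head_dim)
        self.k_buffer = [
            torch.empty(shape, dtype=self.dtype, device=self.device)
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            torch.empty(shape, dtype=self.dtype, device=self.device)
            for _ in range(self.layer_num)
        ]

    def _clear_buffers(self):
        # Dropping the references lets the allocator reclaim the memory.
        del self.k_buffer
        del self.v_buffer


pool = KVPoolSketch(16, 2, 8, 2, torch.float16, "cpu")
pool._clear_buffers()   # e.g. free KV memory before reloading weights
pool._create_buffers()  # rebuild from the stored shape parameters
```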
```diff
@@ -245,7 +254,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
@@ -298,7 +306,6 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
 
 class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
-
     def __init__(
         self,
         size: int,
```
sglang/srt/mem_cache/radix_cache.py:

```diff
@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the Radix tree.
+            The last node create a new child if the prefix is shorter
+            than the last node's value.
+        """
         if self.disable:
             return [], self.root_node
 
```
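The new docstring pins down `match_prefix`'s contract. The core idea, reduced to a toy trie over token ids (illustrative only; `RadixCache` returns a tensor of indices plus a tree node, and may split a node so its value equals the matched prefix):

```python
def match_prefix(trie: dict, key: list) -> list:
    """Walk the trie and return the longest cached prefix of `key`."""
    node, matched = trie, []
    for token in key:
        if token not in node:
            break
        matched.append(token)
        node = node[token]
    return matched


cache = {1: {2: {3: {4: {}}}}}  # the cache holds the sequence [1, 2, 3, 4]
assert match_prefix(cache, [1, 2, 3, 9]) == [1, 2, 3]
assert match_prefix(cache, [7]) == []
```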