sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +60 -23
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +52 -167
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +130 -43
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +49 -11
- sglang/srt/model_executor/forward_batch_info.py +59 -27
- sglang/srt/model_executor/model_runner.py +210 -61
- sglang/srt/models/chatglm.py +4 -12
- sglang/srt/models/commandr.py +5 -1
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +5 -1
- sglang/srt/models/deepseek_v2.py +5 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +15 -7
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +16 -2
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +5 -1
- sglang/srt/models/mixtral.py +5 -1
- sglang/srt/models/mixtral_quant.py +5 -1
- sglang/srt/models/qwen.py +5 -2
- sglang/srt/models/qwen2.py +13 -3
- sglang/srt/models/qwen2_moe.py +5 -14
- sglang/srt/models/stablelm.py +5 -1
- sglang/srt/openai_api/adapter.py +117 -37
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
- sglang/srt/server.py +84 -56
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +23 -31
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -54,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -85,10 +88,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -96,6 +95,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -135,11 +135,17 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)

         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
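The new broadcast above makes every tensor-parallel rank adopt rank 0's random seed before set_random_seed, so samplers stay in sync across ranks. A minimal sketch of the pattern broadcast_recv_input relies on (pickle the object on rank 0, broadcast its size and bytes over the CPU process group); the helper name and details here are illustrative, not sglang's exact implementation, and assume an already-initialized process group:

import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank, dist_group):
    """Broadcast a picklable object from rank 0 to all ranks (sketch)."""
    if rank == 0:
        serialized = pickle.dumps(data)
        buf = torch.frombuffer(bytearray(serialized), dtype=torch.uint8)
        size = torch.tensor([buf.numel()], dtype=torch.long)
        # Rank 0 announces the payload size, then sends the bytes.
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(buf, src=0, group=dist_group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        buf = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(buf, src=0, group=dist_group)
        return pickle.loads(buf.numpy().tobytes())

Using the CPU (gloo) group for this control-plane exchange keeps it off the NCCL communicators used for model tensors.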
@@ -175,6 +181,13 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -211,6 +224,9 @@ class ModelTpServer:
                    self.flush_cache()
                elif isinstance(recv_req, AbortReq):
                    self.abort_request(recv_req)
+                elif isinstance(recv_req, UpdateWeightReqInput):
+                    success, message = self.update_weights(recv_req)
+                    self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
                else:
                    raise ValueError(f"Invalid request: {recv_req}")

@@ -268,7 +284,7 @@ class ModelTpServer:
            self.num_generated_tokens = 0
            self.last_stats_tic = time.time()
            logger.info(
-                f"
+                f"Decode batch. "
                f"#running-req: {len(self.running_batch.reqs)}, "
                f"#token: {num_used}, "
                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -307,11 +323,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (
-                    (
-                    (
-                    (
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
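The four pad values above are derived from one image hash so that identical images produce identical padding tokens (and therefore reusable radix-cache prefixes). A small illustration of the arithmetic with a made-up 64-bit hash and a typical vocab size; both values are invented for the example:

# Illustrative values only: a fake 64-bit hash and an assumed vocab size.
image_hash = 0x9E3779B97F4A7C15
vocab_size = 32000

pad_value = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,  # 0 for a 64-bit hash; kept for parity with the diff
]
print(pad_value)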
@@ -366,11 +387,14 @@ class ModelTpServer:
        # Get priority queue
        prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
        adder = PrefillAdder(
            self.tree_cache,
            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
            self.max_prefill_tokens,
            self.chunked_prefill_size,
+            num_mixed_running,
        )

        if self.running_batch is not None:
@@ -416,15 +440,27 @@ class ModelTpServer:
            )
        else:
            tree_cache_hit_rate = 0.0
-
-
-
-
-
-
-
-
-
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )

        # Return the new batch
        new_batch = ScheduleBatch.init_new(
@@ -440,21 +476,39 @@ class ModelTpServer:
        # Build batch tensors
        batch.prepare_for_extend(self.model_config.vocab_size)

+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
        if self.model_runner.is_generation:
            # Forward and sample the next tokens
            if batch.extend_num_tokens != 0:
-
-
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )

                # Move logprobs to cpu
-                if
-
-
-
-
-
-
-
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                    )

                next_token_ids = next_token_ids.tolist()
@@ -477,9 +531,15 @@ class ModelTpServer:
                req.output_ids.append(next_token_ids[i])
                req.check_finished()

+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.get_next_state(
+                        req.regex_fsm_state, next_token_ids[i]
+                    )
+
                if req.finished():
                    self.tree_cache.cache_finished_req(req)
-
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                    self.tree_cache.cache_unfinished_req(req)

                if req is self.current_inflight_req:
@@ -487,12 +547,14 @@ class ModelTpServer:
                    self.req_to_token_pool.free(req.req_pool_idx)

                if req.return_logprob:
-                    self.add_logprob_return_values(
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                    pt += req.extend_input_len
        else:
            assert batch.extend_num_tokens != 0
-
-            embeddings =
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()

            # Check finish conditions
            for i, req in enumerate(batch.reqs):
@@ -520,7 +582,7 @@ class ModelTpServer:
        req: Req,
        pt: int,
        next_token_ids: List[int],
-        output:
+        output: LogitsProcessorOutput,
    ):
        if req.normalized_prompt_logprob is None:
            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -579,7 +641,7 @@ class ModelTpServer:
        self.new_token_ratio = new_token_ratio

        logger.info(
-            "
+            "Decode out of memory happened. "
            f"#retracted_reqs: {len(retracted_reqs)}, "
            f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
        )
@@ -602,12 +664,17 @@ class ModelTpServer:
        batch.prepare_for_decode()

        # Forward and sample the next tokens
-
-
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )

        # Move logprobs to cpu
-        if
-        next_token_logprobs =
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                torch.arange(len(next_token_ids), device=next_token_ids.device),
                next_token_ids,
            ].tolist()
@@ -620,6 +687,11 @@ class ModelTpServer:
            req.output_ids.append(next_token_id)
            req.check_finished()

+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
            if req.finished():
                self.tree_cache.cache_finished_req(req)

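Both the prefill and decode paths above now advance the request's regex FSM state with each freshly sampled token. A toy sketch of that bookkeeping, using a stand-in FSM that exposes the same get_next_state(state, token_id) interface seen in the diff; the transition table and class name are invented for illustration:

class ToyRegexFSM:
    """A toy stand-in for the constrained-decoding FSM held by each request."""

    def __init__(self, transitions, start_state=0):
        # transitions: dict mapping (state, token_id) -> next_state
        self.transitions = transitions
        self.start_state = start_state

    def get_next_state(self, state, token_id):
        # Unknown transitions fall into a dead state (-1).
        return self.transitions.get((state, token_id), -1)


# After each sampled token, the scheduler records the new FSM state on the request.
fsm = ToyRegexFSM({(0, 7): 1, (1, 9): 2})
state = fsm.start_state
for token_id in (7, 9):
    state = fsm.get_next_state(state, token_id)
print(state)  # 2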
@@ -628,7 +700,7 @@ class ModelTpServer:
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

        self.handle_finished_requests(batch)

@@ -743,12 +815,15 @@ class ModelTpServer:
            self.token_to_kv_pool.clear()
            torch.cuda.empty_cache()
            logger.info("Cache flushed successfully!")
+            if_success = True
        else:
-
+            logging.warning(
                f"Cache not flushed because there are pending requests. "
                f"#queue-req: {len(self.waiting_queue)}, "
                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
+            if_success = False
+        return if_success

    def abort_request(self, recv_req):
        # Delete requests in the waiting queue
@@ -768,6 +843,15 @@ class ModelTpServer:
                    req.finished_reason = FINISH_ABORT()
                    break

+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+

def run_tp_server(
    gpu_id: int,
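The new update_weights path reloads the model weights in place and then flushes the radix/KV cache, since cached entries were produced by the old weights. A rough, self-contained sketch of that round trip; the dataclasses below only mirror the fields the worker reads (model_path, load_format) and are not the actual io_struct definitions:

from dataclasses import dataclass


@dataclass
class UpdateWeightReqInput:  # assumed shape, mirroring the fields used in the diff
    model_path: str
    load_format: str


@dataclass
class UpdateWeightReqOutput:  # assumed shape
    success: bool
    message: str


def handle_update_weights(worker, recv_req: UpdateWeightReqInput) -> UpdateWeightReqOutput:
    # Reload weights through the model runner, as the TP worker does.
    success, message = worker.model_runner.update_weights(
        recv_req.model_path, recv_req.load_format
    )
    if success:
        # Stale KV entries refer to the old weights, so the cache must be flushed.
        assert worker.flush_cache(), "Cache flush failed after updating weights"
    return UpdateWeightReqOutput(success, message)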
@@ -776,7 +860,9 @@ def run_tp_server(
    nccl_port: int,
    model_overide_args: dict,
):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
    try:
        model_server = ModelTpServer(
            gpu_id,
@@ -832,6 +918,7 @@ def broadcast_recv_input(

        dist.broadcast(tensor_size, src=0, group=dist_group)
        dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
    else:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""

 import logging
-from
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union

 import torch

@@ -52,14 +53,21 @@ class ReqToTokenPool:
        self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
    """A memory pool that maps a token to its kv cache locations"""

    def __init__(
        self,
        size: int,
+        dtype: torch.dtype,
    ):
        self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype

        # We also add one slot. This slot is used for writing dummy output from padded tokens.
        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
        # We also add one slot. This slot is used for writing dummy output from padded tokens.
        self.mem_state[0] = False

+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+

class MHATokenToKVPool(BaseTokenToKVPool):

@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
        head_dim: int,
        layer_num: int,
    ):
-        super().__init__(size)
+        super().__init__(size, dtype)

        # [size, head_num, head_dim] for each layer
        self.k_buffer = [
-            torch.empty(
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
            for _ in range(layer_num)
        ]
        self.v_buffer = [
-            torch.empty(
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
-        return self.
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v


class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
        qk_rope_head_dim: int,
        layer_num: int,
    ):
-        super().__init__(size)
+        super().__init__(size, dtype)

        self.kv_lora_rank = kv_lora_rank
        self.kv_buffer = [
            torch.empty(
                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=
+                dtype=self.store_dtype,
                device="cuda",
            )
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
        return self.kv_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
sglang/srt/mm_utils.py
CHANGED
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-# Source: https://github.com/
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This python file mainly contains utilities that were used in the
+image processing logic of llava-next including operations such as
+anyres and anyres_max
+
+Currently supports the anyres and anyres_max operation for CLIP and
+SigLip. For more information, you may refer to the paper or the blog
+
+LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
+
+"""
 import ast
 import base64
 import math
+import re
 from io import BytesIO

 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
+        # Calculate the downscaled size to keep the aspect ratio
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(
            original_height * scale
        )
+
+        # Calculate effective and wasted resolutions
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
    Returns:
        np.array: An np array containing the processed image patches.
    """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        try:
+            patch_size = processor.size[0]
+        except Exception as e:
+            patch_size = processor.size["shortest_edge"]
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
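Both anyres helpers above now accept grid_pinpoints as a compact range string such as "(1x1),...,(6x6)" and expand it into explicit pixel resolutions. A standalone sketch of that expansion; the function name and the example values are illustrative:

import re


def expand_grid_pinpoints(grid_pinpoints: str, patch_size: int):
    """Expand a "(AxB),...,(CxD)" range string into pixel resolutions."""
    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    range_start = tuple(map(int, matches[0]))
    range_end = tuple(map(int, matches[-1]))
    # Enumerate every (rows, cols) grid in the inclusive range, then scale by patch_size.
    grids = [
        (i, j)
        for i in range(range_start[0], range_end[0] + 1)
        for j in range(range_start[1], range_end[1] + 1)
    ]
    return [[dim * patch_size for dim in pair] for pair in grids]


print(expand_grid_pinpoints("(1x1),...,(2x2)", 336))
# [[336, 336], [336, 672], [672, 336], [672, 672]]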
@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

-
-
-
-
+    # For Siglip processor, only have size but no crop size
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
    )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
+    patches = divide_to_patches(image_padded, crop_size)
+
+    image_original_resize = image.resize((shortest_edge, shortest_edge))

    image_patches = [image_original_resize] + patches
    image_patches = [
-        processor.preprocess(image_patch)["pixel_values"][0]
+        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
        for image_patch in image_patches
    ]
    return np.stack(image_patches, axis=0)
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
            )
            image = image_processor.preprocess(image)["pixel_values"][0]
            new_images.append(image)
-        elif
+        elif "anyres" in image_aspect_ratio:
            for image in images:
                image = process_anyres_image(
                    image, image_processor, model_cfg.image_grid_pinpoints