sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
```diff
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -54,7 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -86,10 +88,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -97,6 +95,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -132,18 +131,21 @@ class ModelTpServer:
             ),
             self.model_runner.req_to_token_pool.size - 1,
         )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)

         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
```
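The hunk above drops the old `int_token_logit_bias` setup and instead synchronizes the random seed across tensor-parallel ranks: rank 0's `server_args.random_seed` is broadcast through `broadcast_recv_input` (defined near the end of this file) before `set_random_seed` is called, so every rank samples identically. Below is a minimal sketch of that broadcast pattern with `torch.distributed`; `broadcast_pyobj` is an illustrative name rather than sglang's API, and it assumes a CPU (gloo) process group is already initialized, matching the `tp_group.cpu_group` argument in the diff.

```python
# Hedged sketch of broadcasting a Python object (e.g. a random seed) from
# rank 0 to all ranks, in the spirit of broadcast_recv_input above.
# Assumes dist.init_process_group("gloo", ...) has already been called.
import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank: int, dist_group):
    """Rank 0 sends `data`; every rank returns the same object."""
    if rank == 0:
        payload = torch.frombuffer(bytearray(pickle.dumps(data)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(payload, src=0, group=dist_group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        payload = torch.empty(int(size.item()), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=dist_group)
        return pickle.loads(payload.numpy().tobytes())
```

Serializing with pickle and sending the length first lets an arbitrary Python object ride on two plain tensor broadcasts, which is why a single int seed can be shared without any extra protocol.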
```diff
@@ -179,6 +181,13 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -215,6 +224,9 @@ class ModelTpServer:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -272,7 +284,7 @@ class ModelTpServer:
                 self.num_generated_tokens = 0
                 self.last_stats_tic = time.time()
                 logger.info(
-                    f"
+                    f"Decode batch. "
                     f"#running-req: {len(self.running_batch.reqs)}, "
                     f"#token: {num_used}, "
                     f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -311,11 +323,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (
-                    (
-                    (
-                    (
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
@@ -370,11 +387,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )

         if self.running_batch is not None:
@@ -420,15 +440,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-
-
-
-
-
-
-
-
-
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )

         # Return the new batch
         new_batch = ScheduleBatch.init_new(
@@ -442,25 +474,41 @@ class ModelTpServer:

     def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
-        batch.prepare_for_extend(
-
-
+        batch.prepare_for_extend(self.model_config.vocab_size)
+
+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None

         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-
-
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )

                 # Move logprobs to cpu
-                if
-
-
-
-
-
-
-
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                     )

                 next_token_ids = next_token_ids.tolist()
@@ -483,9 +531,15 @@ class ModelTpServer:
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()

+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.get_next_state(
+                        req.regex_fsm_state, next_token_ids[i]
+                    )
+
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)

                 if req is self.current_inflight_req:
@@ -493,12 +547,14 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)

                 if req.return_logprob:
-                    self.add_logprob_return_values(
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-
-            embeddings =
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()

             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -526,7 +582,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output:
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -585,7 +641,7 @@ class ModelTpServer:
             self.new_token_ratio = new_token_ratio

             logger.info(
-                "
+                "Decode out of memory happened. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
@@ -608,12 +664,17 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-
-
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )

         # Move logprobs to cpu
-        if
-            next_token_logprobs =
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -626,6 +687,11 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()

+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)

@@ -634,7 +700,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -749,12 +815,15 @@ class ModelTpServer:
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
-
+            logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success

     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
@@ -774,6 +843,15 @@ class ModelTpServer:
                 req.finished_reason = FINISH_ABORT()
                 break

+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+

 def run_tp_server(
     gpu_id: int,
@@ -782,7 +860,9 @@ def run_tp_server(
     nccl_port: int,
     model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,
@@ -838,6 +918,7 @@ def broadcast_recv_input(

         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
```
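In the prefill path above, the image hash is folded into four pad token ids: the hash is shifted by 0, 16, 32, and 64 bits and each window is reduced modulo the vocabulary size, so padded image regions get a token pattern that is stable for the same image and unlikely to collide between different images. A tiny self-contained illustration of that arithmetic, with made-up values:

```python
# Illustration of the pad_value computation above; vocab_size and the
# hashed object are made-up examples.
vocab_size = 32000
image_hash = hash(("example-image-bytes",))  # any Python int

pad_value = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,   # a different 16-bit window of the hash
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_value)  # four deterministic token ids derived from one hash
```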
sglang/srt/mem_cache/memory_pool.py
CHANGED
```diff
@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""

 import logging
-from
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union

 import torch

@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False

+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+

 class MHATokenToKVPool(BaseTokenToKVPool):

@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)

         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty(
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty(
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]

     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]

     def get_kv_buffer(self, layer_id: int):
-        return self.
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v


 class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)

         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]

     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
```
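The new `store_dtype` logic works around `index_put` not being implemented for `torch.float8_e5m2`: the KV buffers are allocated as `torch.uint8`, writes reinterpret the fp8 bytes as uint8 with `Tensor.view(dtype)`, and reads reinterpret them back. A minimal sketch of the same trick, assuming a PyTorch build that ships the float8 dtypes (2.1 or newer); shapes and values are arbitrary:

```python
# Hedged sketch of the uint8-backed fp8 storage used by set_kv_buffer above.
import torch

logical_dtype = torch.float8_e5m2  # what attention kernels want to see
store_dtype = torch.uint8          # same 1-byte elements, but indexing works

buf = torch.empty(8, 4, dtype=store_dtype)       # KV slots
new_k = torch.randn(2, 4, dtype=torch.float16)   # incoming cache values
loc = torch.tensor([3, 5])                       # slots to write

# Write: cast to fp8, then reinterpret the bytes as uint8 for index_put.
buf[loc] = new_k.to(logical_dtype).view(store_dtype)

# Read: reinterpret the stored bytes back as fp8 (lossy vs. the fp16 input).
restored = buf[loc].view(logical_dtype).to(torch.float16)
print(restored)
```

`get_key_buffer` does the read-side reinterpretation once for the whole buffer (`self.k_buffer[layer_id].view(self.dtype)`), so downstream kernels never see the uint8 detour.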
sglang/srt/mm_utils.py
CHANGED
```diff
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-# Source: https://github.com/
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This python file mainly contains utilities that were used in the
+image processing logic of llava-next including operations such as
+anyres and anyres_max
+
+Currently supports the anyres and anyres_max operation for CLIP and
+SigLip. For more information, you may refer to the paper or the blog
+
+LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
+
+"""
 import ast
 import base64
 import math
+import re
 from io import BytesIO

 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
     min_wasted_resolution = float("inf")

     for width, height in possible_resolutions:
+        # Calculate the downscaled size to keep the aspect ratio
         scale = min(width / original_width, height / original_height)
         downscaled_width, downscaled_height = int(original_width * scale), int(
             original_height * scale
         )
+
+        # Calculate effective and wasted resolutions
         effective_resolution = min(
             downscaled_width * downscaled_height, original_width * original_height
         )
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Returns:
         tuple: The shape of the image patch grid in the format (width, height).
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
     Returns:
         np.array: An np array containing the processed image patches.
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        try:
+            patch_size = processor.size[0]
+        except Exception as e:
+            patch_size = processor.size["shortest_edge"]
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
     best_resolution = select_best_resolution(image.size, possible_resolutions)
     image_padded = resize_and_pad_image(image, best_resolution)

-
-
-
-
+    # For Siglip processor, only have size but no crop size
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
     )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
+    patches = divide_to_patches(image_padded, crop_size)
+
+    image_original_resize = image.resize((shortest_edge, shortest_edge))

     image_patches = [image_original_resize] + patches
     image_patches = [
-        processor.preprocess(image_patch)["pixel_values"][0]
+        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
         for image_patch in image_patches
     ]
     return np.stack(image_patches, axis=0)
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
             )
             image = image_processor.preprocess(image)["pixel_values"][0]
             new_images.append(image)
-    elif
+    elif "anyres" in image_aspect_ratio:
         for image in images:
             image = process_anyres_image(
                 image, image_processor, model_cfg.image_grid_pinpoints
```