sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff compares two publicly released versions of the package as published to its public registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in that registry.
- sglang/api.py +6 -0
- sglang/bench_latency.py +7 -3
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +9 -0
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +100 -1
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/fused_moe/layer.py +2 -2
- sglang/srt/layers/logits_processor.py +56 -19
- sglang/srt/layers/radix_attention.py +3 -4
- sglang/srt/layers/sampler.py +101 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +46 -166
- sglang/srt/managers/tokenizer_manager.py +192 -83
- sglang/srt/managers/tp_worker.py +118 -24
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +32 -8
- sglang/srt/model_executor/forward_batch_info.py +51 -26
- sglang/srt/model_executor/model_runner.py +201 -58
- sglang/srt/models/gemma2.py +10 -6
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +11 -1
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/qwen2.py +9 -3
- sglang/srt/openai_api/adapter.py +200 -39
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +136 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
- sglang/srt/server.py +92 -57
- sglang/srt/server_args.py +43 -15
- sglang/srt/utils.py +26 -16
- sglang/test/runners.py +22 -30
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_utils.py +36 -53
- sglang/version.py +1 -1
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
- sglang-0.2.14.post1.dist-info/RECORD +114 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang-0.2.13.dist-info/RECORD +0 -112
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -54,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -85,10 +88,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -96,6 +95,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -135,11 +135,17 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)

         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -175,6 +181,13 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -184,6 +197,16 @@ class ModelTpServer:
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
+                json_schema_mode=False,
+            )
+            self.json_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+                json_schema_mode=True,
             )
         self.jump_forward_cache = JumpForwardCache()

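The new `json_fsm_cache` mirrors the existing regex FSM cache but is keyed by a JSON schema: its query returns both a compiled FSM and the regex derived from the schema, which the jump-forward cache consumes later in `handle_generate_request`. The sketch below is only a toy illustration of that schema-to-regex step; it is not the converter sglang actually uses, and `toy_schema_to_regex` is a hypothetical helper that handles nothing beyond flat string enums.

```python
import json
import re


def toy_schema_to_regex(schema: dict) -> str:
    # Build an alternation per enum property; real converters cover full JSON Schema.
    parts = []
    for name, prop in schema["properties"].items():
        choices = "|".join(re.escape(v) for v in prop["enum"])
        parts.append(f'"{re.escape(name)}"\\s*:\\s*"(?:{choices})"')
    return r"\{\s*" + r"\s*,\s*".join(parts) + r"\s*\}"


schema = {"properties": {"sentiment": {"enum": ["positive", "negative"]}}}
pattern = toy_schema_to_regex(schema)
assert re.fullmatch(pattern, json.dumps({"sentiment": "positive"}))
```

The fixed substrings in such a regex (key names, quotes, braces) are exactly what a jump-forward map can emit without sampling, which is why the computed regex string is handed to `jump_forward_cache.query` as well.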
@@ -211,6 +234,9 @@ class ModelTpServer:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -268,7 +294,7 @@ class ModelTpServer:
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
             logger.info(
-                f"
+                f"Decode batch. "
                 f"#running-req: {len(self.running_batch.reqs)}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -307,11 +333,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (
-                    (
-                    (
-                    (
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
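The hunk above first folds a list-valued `image_hash` into a single integer with `hash(tuple(...))`, then derives four pad-token ids from different 16-bit windows of that hash, each reduced modulo the vocabulary size. A standalone sketch of the same arithmetic, with made-up example values for the hash input and vocabulary size:

```python
vocab_size = 32000
image_hash = hash(tuple([987654321, 123456789]))  # list-valued hashes are folded first

pad_value = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_value)  # four deterministic token ids in [0, vocab_size)
```

Because the pad values depend only on the image hash, padded image regions of the same image always map to the same token ids, which keeps prefix caching consistent across requests that share an image.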
@@ -328,8 +359,17 @@ class ModelTpServer:
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream

+        # Init regex fsm fron json
+        if req.sampling_params.json_schema is not None:
+            req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                req.sampling_params.json_schema
+            )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
         # Init regex fsm
-
+        elif req.sampling_params.regex is not None:
             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
             if not self.disable_regex_jump_forward:
                 req.jump_forward_map = self.jump_forward_cache.query(
@@ -366,11 +406,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )

         if self.running_batch is not None:
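`num_mixed_running` tells the PrefillAdder how many decode requests will share the batch when mixed chunked prefill is enabled, so the prefill token budget can leave room for their one-token decode steps. A rough sketch of that kind of accounting, under the assumption that each running decode request reserves one token of the chunk budget; `chunk_budget` is a hypothetical helper, not sglang's PrefillAdder, whose real accounting is more involved:

```python
def chunk_budget(chunked_prefill_size: int, num_mixed_running: int) -> int:
    """Tokens left for new prefill after reserving one decode token per running request."""
    return max(chunked_prefill_size - num_mixed_running, 0)


assert chunk_budget(8192, 0) == 8192   # pure prefill batch
assert chunk_budget(8192, 64) == 8128  # mixed batch: 64 decode tokens reserved
```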
@@ -416,15 +459,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-
-
-
-
-
-
-
-
-
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )

         # Return the new batch
         new_batch = ScheduleBatch.init_new(
@@ -440,11 +495,21 @@ class ModelTpServer:
         # Build batch tensors
         batch.prepare_for_extend(self.model_config.vocab_size)

+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
                 next_token_ids = batch.sample(output.next_token_logits)
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )

             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
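After each sampling step the batch's penalizer orchestrator is now fed the freshly sampled token ids, so frequency- and presence-style penalties always see the tokens generated so far. Below is a minimal sketch of that bookkeeping; `FrequencyPenalizer` is a hypothetical class for illustration, not sglang's penalizer orchestrator.

```python
import torch


class FrequencyPenalizer:
    def __init__(self, batch_size: int, vocab_size: int, penalty: float = 0.3):
        self.penalty = penalty
        self.counts = torch.zeros(batch_size, vocab_size)

    def cumulate_output_tokens(self, next_token_ids: torch.Tensor) -> None:
        # One new token per running request; bump its per-request count.
        rows = torch.arange(next_token_ids.numel())
        self.counts[rows, next_token_ids] += 1

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # Penalize tokens proportionally to how often they were already emitted.
        return logits - self.penalty * self.counts


penalizer = FrequencyPenalizer(batch_size=2, vocab_size=8)
penalizer.cumulate_output_tokens(torch.tensor([3, 5]))
logits = penalizer.apply(torch.zeros(2, 8))
assert logits[0, 3] == -0.3 and logits[1, 5] == -0.3
```

Calling the accumulator right after `batch.sample(...)` (in both the extend and decode paths) is what keeps the penalty state in sync with the tokens that were actually emitted.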
@@ -477,9 +542,15 @@ class ModelTpServer:
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()

+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.get_next_state(
+                        req.regex_fsm_state, next_token_ids[i]
+                    )
+
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)

                 if req is self.current_inflight_req:
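In both the extend path above and the decode path further down, the worker now advances `req.regex_fsm_state` with the token that was actually sampled, so the FSM always reflects the emitted text and can mask the next step's logits. The toy DFA below has the same `get_next_state` shape for illustration only; the real per-request FSM in sglang is built from a regex and operates on tokenizer token ids, not characters.

```python
class ToyRegexFSM:
    """Accepts the language 'yes' | 'no'."""

    def __init__(self):
        self.transitions = {
            (0, "y"): 1, (1, "e"): 2, (2, "s"): 3,  # "yes"
            (0, "n"): 4, (4, "o"): 3,               # "no"
        }

    def get_next_state(self, state: int, symbol: str) -> int:
        return self.transitions.get((state, symbol), -1)  # -1 = dead state

    def allowed_symbols(self, state: int):
        return [s for (st, s) in self.transitions if st == state]


fsm = ToyRegexFSM()
state = 0
for ch in "yes":
    assert ch in fsm.allowed_symbols(state)  # mask: only these symbols may be sampled
    state = fsm.get_next_state(state, ch)
assert state == 3  # accepting state
```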
@@ -579,7 +650,7 @@ class ModelTpServer:
             self.new_token_ratio = new_token_ratio

             logger.info(
-                "
+                "Decode out of memory happened. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
@@ -604,6 +675,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids = batch.sample(output.next_token_logits)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )

         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -620,6 +694,11 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()

+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)

@@ -743,12 +822,15 @@ class ModelTpServer:
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
-
+            logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success

     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
@@ -768,6 +850,15 @@ class ModelTpServer:
                 req.finished_reason = FINISH_ABORT()
                 break

+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+

 def run_tp_server(
     gpu_id: int,
@@ -776,7 +867,9 @@ def run_tp_server(
     nccl_port: int,
     model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,
@@ -832,6 +925,7 @@ def broadcast_recv_input(

         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
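`broadcast_recv_input` now returns the input objects on rank 0 as well, which is what lets the constructor reuse it to sync `random_seed` across tensor-parallel ranks: `[server_args.random_seed]` goes in on rank 0 and the same one-element list comes back on every rank, so `)[0]` gives all workers an identical seed. The pattern is the usual pickle-then-broadcast over a CPU process group. The sketch below shows that pattern in isolation; `broadcast_pyobj` is a hypothetical helper (not sglang's implementation) and assumes `torch.distributed` is already initialized with a gloo/CPU `dist_group`.

```python
import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank: int, dist_group) -> list:
    """Rank 0 pickles the objects, everyone receives size then bytes, all ranks return the list."""
    if rank == 0:
        buf = torch.frombuffer(bytearray(pickle.dumps(data)), dtype=torch.uint8)
        size = torch.tensor([buf.numel()], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(buf, src=0, group=dist_group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        buf = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(buf, src=0, group=dist_group)
        return pickle.loads(buf.numpy().tobytes())
```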
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""

 import logging
-from
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union

 import torch

@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False

+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+

 class MHATokenToKVPool(BaseTokenToKVPool):

@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)

         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty(
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty(
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]

     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]

     def get_kv_buffer(self, layer_id: int):
-        return self.
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v


 class MLATokenToKVPool(BaseTokenToKVPool):
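The pool stores `float8_e5m2` KV entries as `uint8` because, as the NOTE in the hunk says, `index_put` is not implemented for fp8, and then reinterprets on read and write with `Tensor.view(dtype)`, a zero-copy reinterpretation that is valid because both dtypes are one byte wide. A small CPU-only round-trip sketch of the same trick (requires a PyTorch build with `float8_e5m2`, i.e. 2.1+):

```python
import torch

store = torch.zeros(4, dtype=torch.uint8)              # what the pool allocates
cache_k = torch.tensor([0.5, 1.0, 2.0, 3.0]).to(torch.float8_e5m2)

loc = torch.tensor([0, 1, 2, 3])
store[loc] = cache_k.view(torch.uint8)                 # write path: reinterpret fp8 bytes
read_back = store.view(torch.float8_e5m2)              # read path: reinterpret back

assert torch.equal(read_back.to(torch.float32), cache_k.to(torch.float32))
```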
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)

         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]

     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
sglang/srt/mm_utils.py
CHANGED
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-# Source: https://github.com/
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This python file mainly contains utilities that were used in the
+image processing logic of llava-next including operations such as
+anyres and anyres_max
+
+Currently supports the anyres and anyres_max operation for CLIP and
+SigLip. For more information, you may refer to the paper or the blog
+
+LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
+
+"""
 import ast
 import base64
 import math
+import re
 from io import BytesIO

 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
     min_wasted_resolution = float("inf")

     for width, height in possible_resolutions:
+        # Calculate the downscaled size to keep the aspect ratio
         scale = min(width / original_width, height / original_height)
         downscaled_width, downscaled_height = int(original_width * scale), int(
             original_height * scale
         )
+
+        # Calculate effective and wasted resolutions
         effective_resolution = min(
             downscaled_width * downscaled_height, original_width * original_height
         )
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Returns:
         tuple: The shape of the image patch grid in the format (width, height).
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
     Returns:
         np.array: An np array containing the processed image patches.
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        try:
+            patch_size = processor.size[0]
+        except Exception as e:
+            patch_size = processor.size["shortest_edge"]
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
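Both functions now accept `grid_pinpoints` as a range string such as "(1x1),...,(6x6)", expand it to every (cols, rows) pair in that range, and scale by the processor patch size before the usual best-resolution search. The standalone sketch below reproduces that parsing and pairs it with an inline reimplementation of the `select_best_resolution` criterion so it runs without importing sglang; `expand_grid_pinpoints` and `pick_best_resolution` are illustrative helpers, not functions from this file.

```python
import re


def expand_grid_pinpoints(spec: str, patch_size: int) -> list:
    """Expand a range spec like "(1x1),...,(6x6)" into pixel resolutions."""
    matches = re.findall(r"\((\d+)x(\d+)\)", spec)
    (x0, y0), (x1, y1) = tuple(map(int, matches[0])), tuple(map(int, matches[-1]))
    return [
        [i * patch_size, j * patch_size]
        for i in range(x0, x1 + 1)
        for j in range(y0, y1 + 1)
    ]


def pick_best_resolution(original_size, possible_resolutions):
    """Maximize effective resolution, then minimize wasted area (same criterion as above)."""
    ow, oh = original_size
    best, best_key = None, None
    for w, h in possible_resolutions:
        scale = min(w / ow, h / oh)
        effective = min(int(ow * scale) * int(oh * scale), ow * oh)
        wasted = w * h - effective
        key = (-effective, wasted)
        if best_key is None or key < best_key:
            best, best_key = (w, h), key
    return best


pinpoints = expand_grid_pinpoints("(1x1),...,(6x6)", patch_size=336)
assert [336, 336] in pinpoints and [2016, 2016] in pinpoints
print(pick_best_resolution((800, 600), pinpoints))
```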
@@ -156,15 +219,24 @@
     best_resolution = select_best_resolution(image.size, possible_resolutions)
     image_padded = resize_and_pad_image(image, best_resolution)

-
-
-
-
+    # For Siglip processor, only have size but no crop size
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
     )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
+    patches = divide_to_patches(image_padded, crop_size)
+
+    image_original_resize = image.resize((shortest_edge, shortest_edge))

     image_patches = [image_original_resize] + patches
     image_patches = [
-        processor.preprocess(image_patch)["pixel_values"][0]
+        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
         for image_patch in image_patches
     ]
     return np.stack(image_patches, axis=0)
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
             )
             image = image_processor.preprocess(image)["pixel_values"][0]
             new_images.append(image)
-        elif
+        elif "anyres" in image_aspect_ratio:
             for image in images:
                 image = process_anyres_image(
                     image, image_processor, model_cfg.image_grid_pinpoints