sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.layers.logits_processor import LogitProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.io_struct import (
  AbortReq,
  BatchEmbeddingOut,
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
  FlushCacheReq,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
+ UpdateWeightReqInput,
+ UpdateWeightReqOutput,
  )
  from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
  from sglang.srt.managers.schedule_batch import (
@@ -54,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
+ configure_logger,
  is_multimodal_model,
  set_random_seed,
  suppress_other_loggers,
@@ -85,10 +88,6 @@ class ModelTpServer:
  self.schedule_policy = server_args.schedule_policy
  self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

- # Chunked prefill
- self.chunked_prefill_size = server_args.chunked_prefill_size
- self.current_inflight_req = None
-
  # Init model and tokenizer
  self.model_config = ModelConfig(
  server_args.model_path,
@@ -96,6 +95,7 @@ class ModelTpServer:
  context_length=server_args.context_length,
  model_overide_args=model_overide_args,
  )
+
  self.model_runner = ModelRunner(
  model_config=self.model_config,
  mem_fraction_static=server_args.mem_fraction_static,
@@ -135,11 +135,17 @@ class ModelTpServer:
  self.model_config.context_len - 1,
  self.max_total_num_tokens - 1,
  )
+
+ # Sync random seed
+ server_args.random_seed = broadcast_recv_input(
+ [server_args.random_seed],
+ self.tp_rank,
+ self.model_runner.tp_group.cpu_group,
+ )[0]
  set_random_seed(server_args.random_seed)

  # Print info
  logger.info(
- f"[gpu={self.gpu_id}] "
  f"max_total_num_tokens={self.max_total_num_tokens}, "
  f"max_prefill_tokens={self.max_prefill_tokens}, "
  f"max_running_requests={self.max_running_requests}, "
@@ -175,6 +181,13 @@ class ModelTpServer:
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()

+ # Chunked prefill
+ self.chunked_prefill_size = server_args.chunked_prefill_size
+ self.current_inflight_req = None
+ self.is_mixed_chunk = (
+ self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+ )
+
  # Init the FSM cache for constrained generation
  if not server_args.skip_tokenizer_init:
  self.regex_fsm_cache = FSMCache(
@@ -211,6 +224,9 @@ class ModelTpServer:
  self.flush_cache()
  elif isinstance(recv_req, AbortReq):
  self.abort_request(recv_req)
+ elif isinstance(recv_req, UpdateWeightReqInput):
+ success, message = self.update_weights(recv_req)
+ self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
  else:
  raise ValueError(f"Invalid request: {recv_req}")

@@ -268,7 +284,7 @@ class ModelTpServer:
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()
  logger.info(
- f"[gpu={self.gpu_id}] Decode batch. "
+ f"Decode batch. "
  f"#running-req: {len(self.running_batch.reqs)}, "
  f"#token: {num_used}, "
  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -307,11 +323,16 @@ class ModelTpServer:
  if self.model_runner.is_generation:
  req.pixel_values = recv_req.pixel_values
  if req.pixel_values is not None:
+ image_hash = (
+ hash(tuple(recv_req.image_hash))
+ if isinstance(recv_req.image_hash, list)
+ else recv_req.image_hash
+ )
  req.pad_value = [
- (recv_req.image_hash) % self.model_config.vocab_size,
- (recv_req.image_hash >> 16) % self.model_config.vocab_size,
- (recv_req.image_hash >> 32) % self.model_config.vocab_size,
- (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+ (image_hash) % self.model_config.vocab_size,
+ (image_hash >> 16) % self.model_config.vocab_size,
+ (image_hash >> 32) % self.model_config.vocab_size,
+ (image_hash >> 64) % self.model_config.vocab_size,
  ]
  req.image_size = recv_req.image_size
  (
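Note on the hunk above: recv_req.image_hash can now be a list (one hash per image), so it is first collapsed into a single integer before the pad values are derived from different bit ranges of that hash. A minimal standalone sketch of the idea; the vocab_size value and the helper name are illustrative, not part of the diff:

# Sketch: derive deterministic pad token ids from an image hash.
# A list of per-image hashes is first reduced to one integer via hash(tuple(...)).
vocab_size = 32000  # assumed example value

def pad_values_from_hash(image_hash):
    if isinstance(image_hash, list):
        image_hash = hash(tuple(image_hash))
    # Shift by 0/16/32/64 bits and take each result modulo the vocab size.
    return [(image_hash >> shift) % vocab_size for shift in (0, 16, 32, 64)]

print(pad_values_from_hash([123456789, 987654321]))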
@@ -366,11 +387,14 @@ class ModelTpServer:
  # Get priority queue
  prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+ num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
  adder = PrefillAdder(
  self.tree_cache,
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
  self.max_prefill_tokens,
  self.chunked_prefill_size,
+ num_mixed_running,
  )

  if self.running_batch is not None:
@@ -416,15 +440,27 @@ class ModelTpServer:
  )
  else:
  tree_cache_hit_rate = 0.0
- logger.info(
- f"[gpu={self.gpu_id}] Prefill batch. "
- f"#new-seq: {len(can_run_list)}, "
- f"#new-token: {adder.log_input_tokens}, "
- f"#cached-token: {adder.log_hit_tokens}, "
- f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
- f"#running-req: {running_bs}, "
- f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
- )
+
+ if num_mixed_running > 0:
+ logger.info(
+ f"Prefill batch"
+ f"(mixed #running-req: {num_mixed_running}). "
+ f"#new-seq: {len(can_run_list)}, "
+ f"#new-token: {adder.log_input_tokens}, "
+ f"#cached-token: {adder.log_hit_tokens}, "
+ f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+ f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+ )
+ else:
+ logger.info(
+ f"Prefill batch. "
+ f"#new-seq: {len(can_run_list)}, "
+ f"#new-token: {adder.log_input_tokens}, "
+ f"#cached-token: {adder.log_hit_tokens}, "
+ f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+ f"#running-req: {running_bs}, "
+ f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+ )

  # Return the new batch
  new_batch = ScheduleBatch.init_new(
@@ -440,21 +476,39 @@ class ModelTpServer:
  # Build batch tensors
  batch.prepare_for_extend(self.model_config.vocab_size)

+ decoding_reqs = []
+ if self.is_mixed_chunk and self.running_batch is not None:
+ self.running_batch.prepare_for_decode()
+ batch.mix_with_running(self.running_batch)
+ decoding_reqs = self.running_batch.reqs
+ self.running_batch = None
+
  if self.model_runner.is_generation:
  # Forward and sample the next tokens
  if batch.extend_num_tokens != 0:
- output = self.model_runner.forward(batch, ForwardMode.EXTEND)
- next_token_ids = batch.sample(output.next_token_logits)
+ sample_output, logits_output = self.model_runner.forward(
+ batch, ForwardMode.EXTEND
+ )
+ next_token_ids = batch.check_sample_results(sample_output)
+ batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+ next_token_ids
+ )

  # Move logprobs to cpu
- if output.next_token_logprobs is not None:
- output.next_token_logprobs = output.next_token_logprobs[
- torch.arange(len(next_token_ids), device=next_token_ids.device),
- next_token_ids,
- ].tolist()
- output.input_token_logprobs = output.input_token_logprobs.tolist()
- output.normalized_prompt_logprobs = (
- output.normalized_prompt_logprobs.tolist()
+ if logits_output.next_token_logprobs is not None:
+ logits_output.next_token_logprobs = (
+ logits_output.next_token_logprobs[
+ torch.arange(
+ len(next_token_ids), device=next_token_ids.device
+ ),
+ next_token_ids,
+ ].tolist()
+ )
+ logits_output.input_token_logprobs = (
+ logits_output.input_token_logprobs.tolist()
+ )
+ logits_output.normalized_prompt_logprobs = (
+ logits_output.normalized_prompt_logprobs.tolist()
  )

  next_token_ids = next_token_ids.tolist()
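Note on the hunk above: the forward pass now returns a (sample_output, logits_output) pair, and the logprob of each sampled token is picked out of the logprob matrix by indexing with torch.arange and the sampled ids before being moved to the CPU. A small self-contained sketch of that gather pattern, with made-up shapes and values:

import torch

# Per-request logprob gather: for each row, select the logprob of the token
# that was actually sampled, then move the result to the CPU as a Python list.
next_token_logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)  # [batch, vocab]
next_token_ids = torch.tensor([11, 7, 532, 90])  # sampled token ids, one per request

chosen = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
print(chosen)  # one float per request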
@@ -477,9 +531,15 @@ class ModelTpServer:
  req.output_ids.append(next_token_ids[i])
  req.check_finished()

+ if req.regex_fsm is not None:
+ req.regex_fsm_state = req.regex_fsm.get_next_state(
+ req.regex_fsm_state, next_token_ids[i]
+ )
+
  if req.finished():
  self.tree_cache.cache_finished_req(req)
- else:
+ elif req not in decoding_reqs:
+ # To reduce overhead, only cache prefill reqs
  self.tree_cache.cache_unfinished_req(req)

  if req is self.current_inflight_req:
@@ -487,12 +547,14 @@ class ModelTpServer:
  self.req_to_token_pool.free(req.req_pool_idx)

  if req.return_logprob:
- self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+ self.add_logprob_return_values(
+ i, req, pt, next_token_ids, logits_output
+ )
  pt += req.extend_input_len
  else:
  assert batch.extend_num_tokens != 0
- output = self.model_runner.forward(batch, ForwardMode.EXTEND)
- embeddings = output.embeddings.tolist()
+ logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ embeddings = logits_output.embeddings.tolist()

  # Check finish conditions
  for i, req in enumerate(batch.reqs):
@@ -520,7 +582,7 @@ class ModelTpServer:
  req: Req,
  pt: int,
  next_token_ids: List[int],
- output: LogitProcessorOutput,
+ output: LogitsProcessorOutput,
  ):
  if req.normalized_prompt_logprob is None:
  req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -579,7 +641,7 @@ class ModelTpServer:
  self.new_token_ratio = new_token_ratio

  logger.info(
- "decode out of memory happened, "
+ "Decode out of memory happened. "
  f"#retracted_reqs: {len(retracted_reqs)}, "
  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
  )
@@ -602,12 +664,17 @@ class ModelTpServer:
  batch.prepare_for_decode()

  # Forward and sample the next tokens
- output = self.model_runner.forward(batch, ForwardMode.DECODE)
- next_token_ids = batch.sample(output.next_token_logits)
+ sample_output, logits_output = self.model_runner.forward(
+ batch, ForwardMode.DECODE
+ )
+ next_token_ids = batch.check_sample_results(sample_output)
+ batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+ next_token_ids
+ )

  # Move logprobs to cpu
- if output.next_token_logprobs is not None:
- next_token_logprobs = output.next_token_logprobs[
+ if logits_output.next_token_logprobs is not None:
+ next_token_logprobs = logits_output.next_token_logprobs[
  torch.arange(len(next_token_ids), device=next_token_ids.device),
  next_token_ids,
  ].tolist()
@@ -620,6 +687,11 @@ class ModelTpServer:
  req.output_ids.append(next_token_id)
  req.check_finished()

+ if req.regex_fsm is not None:
+ req.regex_fsm_state = req.regex_fsm.get_next_state(
+ req.regex_fsm_state, next_token_id
+ )
+
  if req.finished():
  self.tree_cache.cache_finished_req(req)

@@ -628,7 +700,7 @@ class ModelTpServer:
  (next_token_logprobs[i], next_token_id)
  )
  if req.top_logprobs_num > 0:
- req.output_top_logprobs.append(output.output_top_logprobs[i])
+ req.output_top_logprobs.append(logits_output.output_top_logprobs[i])

  self.handle_finished_requests(batch)

@@ -743,12 +815,15 @@ class ModelTpServer:
  self.token_to_kv_pool.clear()
  torch.cuda.empty_cache()
  logger.info("Cache flushed successfully!")
+ if_success = True
  else:
- warnings.warn(
+ logging.warning(
  f"Cache not flushed because there are pending requests. "
  f"#queue-req: {len(self.waiting_queue)}, "
  f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
  )
+ if_success = False
+ return if_success

  def abort_request(self, recv_req):
  # Delete requests in the waiting queue
@@ -768,6 +843,15 @@ class ModelTpServer:
  req.finished_reason = FINISH_ABORT()
  break

+ def update_weights(self, recv_req):
+ success, message = self.model_runner.update_weights(
+ recv_req.model_path, recv_req.load_format
+ )
+ if success:
+ flash_cache_success = self.flush_cache()
+ assert flash_cache_success, "Cache flush failed after updating weights"
+ return success, message
+

  def run_tp_server(
  gpu_id: int,
@@ -776,7 +860,9 @@ def run_tp_server(
  nccl_port: int,
  model_overide_args: dict,
  ):
- """Run a tensor parallel server."""
+ """Run a tensor parallel model server."""
+ configure_logger(server_args, prefix=f" TP{tp_rank}")
+
  try:
  model_server = ModelTpServer(
  gpu_id,
@@ -832,6 +918,7 @@ def broadcast_recv_input(

  dist.broadcast(tensor_size, src=0, group=dist_group)
  dist.broadcast(tensor_data, src=0, group=dist_group)
+ return data
  else:
  tensor_size = torch.tensor([0], dtype=torch.long)
  dist.broadcast(tensor_size, src=0, group=dist_group)
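Note on the last hunk: broadcast_recv_input ships a pickled Python object from TP rank 0 to the other ranks over a CPU (gloo) process group, and the added return data lets rank 0 get the object back as well, which the new random-seed sync above relies on. A standalone sketch of that pattern under those assumptions, not the exact sglang helper:

import pickle

import torch
import torch.distributed as dist

def broadcast_pyobj(data, rank: int, group):
    """Broadcast a picklable object from rank 0 to all ranks of a gloo group."""
    if rank == 0:
        buf = torch.frombuffer(bytearray(pickle.dumps(data)), dtype=torch.uint8)
        size = torch.tensor([buf.numel()], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(buf, src=0, group=group)
        return data  # the sender also returns the object (the fix above)
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=group)
    buf = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(buf, src=0, group=group)
    return pickle.loads(buf.numpy().tobytes())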
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -16,7 +16,8 @@ limitations under the License.
  """Memory pool."""

  import logging
- from typing import List, Union
+ from abc import ABC, abstractmethod
+ from typing import List, Tuple, Union

  import torch

@@ -52,14 +53,21 @@ class ReqToTokenPool:
  self.free_slots = list(range(self.size))


- class BaseTokenToKVPool:
+ class BaseTokenToKVPool(ABC):
  """A memory pool that maps a token to its kv cache locations"""

  def __init__(
  self,
  size: int,
+ dtype: torch.dtype,
  ):
  self.size = size
+ self.dtype = dtype
+ if dtype == torch.float8_e5m2:
+ # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+ self.store_dtype = torch.uint8
+ else:
+ self.store_dtype = dtype

  # We also add one slot. This slot is used for writing dummy output from padded tokens.
  self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
  # We also add one slot. This slot is used for writing dummy output from padded tokens.
  self.mem_state[0] = False

+ @abstractmethod
+ def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def set_kv_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ cache_k: torch.Tensor,
+ cache_v: torch.Tensor,
+ ) -> None:
+ raise NotImplementedError()
+

  class MHATokenToKVPool(BaseTokenToKVPool):

@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
  head_dim: int,
  layer_num: int,
  ):
- super().__init__(size)
+ super().__init__(size, dtype)

  # [size, head_num, head_dim] for each layer
  self.k_buffer = [
- torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+ torch.empty(
+ (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+ )
  for _ in range(layer_num)
  ]
  self.v_buffer = [
- torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+ torch.empty(
+ (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+ )
  for _ in range(layer_num)
  ]

  def get_key_buffer(self, layer_id: int):
+ if self.store_dtype != self.dtype:
+ return self.k_buffer[layer_id].view(self.dtype)
  return self.k_buffer[layer_id]

  def get_value_buffer(self, layer_id: int):
+ if self.store_dtype != self.dtype:
+ return self.v_buffer[layer_id].view(self.dtype)
  return self.v_buffer[layer_id]

  def get_kv_buffer(self, layer_id: int):
- return self.k_buffer[layer_id], self.v_buffer[layer_id]
+ return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+ def set_kv_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ cache_k: torch.Tensor,
+ cache_v: torch.Tensor,
+ ):
+ if cache_k.dtype != self.dtype:
+ cache_k = cache_k.to(self.dtype)
+ if cache_v.dtype != self.dtype:
+ cache_v = cache_v.to(self.dtype)
+ if self.store_dtype != self.dtype:
+ self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+ self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+ else:
+ self.k_buffer[layer_id][loc] = cache_k
+ self.v_buffer[layer_id][loc] = cache_v


  class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
  qk_rope_head_dim: int,
  layer_num: int,
  ):
- super().__init__(size)
+ super().__init__(size, dtype)

  self.kv_lora_rank = kv_lora_rank
  self.kv_buffer = [
  torch.empty(
  (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
- dtype=dtype,
+ dtype=self.store_dtype,
  device="cuda",
  )
  for _ in range(layer_num)
  ]

  def get_key_buffer(self, layer_id: int):
+ if self.store_dtype != self.dtype:
+ return self.kv_buffer[layer_id].view(self.dtype)
  return self.kv_buffer[layer_id]

  def get_value_buffer(self, layer_id: int):
+ if self.store_dtype != self.dtype:
+ return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
  return self.kv_buffer[layer_id][..., : self.kv_lora_rank]

  def get_kv_buffer(self, layer_id: int):
  return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+ def set_kv_buffer(
+ self,
+ layer_id: int,
+ loc: torch.Tensor,
+ cache_k: torch.Tensor,
+ cache_v: torch.Tensor,
+ ):
+ if cache_k.dtype != self.dtype:
+ cache_k = cache_k.to(self.dtype)
+ if self.store_dtype != self.dtype:
+ self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+ else:
+ self.kv_buffer[layer_id][loc] = cache_k
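Note on the memory-pool hunks above: when the requested KV-cache dtype is torch.float8_e5m2, the buffers are allocated with a separate store_dtype of torch.uint8 (indexed writes are not implemented for fp8), and the same bytes are reinterpreted with .view() on reads and writes. A minimal sketch of that trick, assuming a PyTorch build with float8 support:

import torch

# Keep fp8 data in a uint8 buffer so indexed writes work, and reinterpret the
# raw bytes with .view() when moving between the storage and logical dtypes.
logical_dtype = torch.float8_e5m2
buffer = torch.zeros(8, 4, dtype=torch.uint8)   # storage buffer (store_dtype)

cache = torch.randn(2, 4).to(logical_dtype)     # new KV entries to insert
loc = torch.tensor([1, 5])                      # slots to write
buffer[loc] = cache.view(torch.uint8)           # write: reinterpret fp8 -> uint8

read_back = buffer[loc].view(logical_dtype)     # read: index uint8, view as fp8
print(read_back.float())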
sglang/srt/mm_utils.py CHANGED
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

- # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
+ # Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+ """
+ Utilities for multi-modal models.
+
+ This python file mainly contains utilities that were used in the
+ image processing logic of llava-next including operations such as
+ anyres and anyres_max
+
+ Currently supports the anyres and anyres_max operation for CLIP and
+ SigLip. For more information, you may refer to the paper or the blog
+
+ LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
+ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
+
+ """
  import ast
  import base64
  import math
+ import re
  from io import BytesIO

  import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
  min_wasted_resolution = float("inf")

  for width, height in possible_resolutions:
+ # Calculate the downscaled size to keep the aspect ratio
  scale = min(width / original_width, height / original_height)
  downscaled_width, downscaled_height = int(original_width * scale), int(
  original_height * scale
  )
+
+ # Calculate effective and wasted resolutions
  effective_resolution = min(
  downscaled_width * downscaled_height, original_width * original_height
  )
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
  Returns:
  tuple: The shape of the image patch grid in the format (width, height).
  """
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+ assert patch_size in [
+ 224,
+ 336,
+ 384,
+ 448,
+ 512,
+ ], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [
+ (i, j)
+ for i in range(range_start[0], range_end[0] + 1)
+ for j in range(range_start[1], range_end[1] + 1)
+ ]
+ # Multiply all elements by patch_size
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
  if type(grid_pinpoints) is list:
  possible_resolutions = grid_pinpoints
  else:
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
  Returns:
  np.array: An np array containing the processed image patches.
  """
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+ try:
+ patch_size = processor.size[0]
+ except Exception as e:
+ patch_size = processor.size["shortest_edge"]
+ assert patch_size in [
+ 224,
+ 336,
+ 384,
+ 448,
+ 512,
+ ], "patch_size should be in [224, 336, 384, 448, 512]"
+ # Use regex to extract the range from the input string
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+ range_start = tuple(map(int, matches[0]))
+ range_end = tuple(map(int, matches[-1]))
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+ grid_pinpoints = [
+ (i, j)
+ for i in range(range_start[0], range_end[0] + 1)
+ for j in range(range_start[1], range_end[1] + 1)
+ ]
+ # Multiply all elements by patch_size
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
  if type(grid_pinpoints) is list:
  possible_resolutions = grid_pinpoints
  else:
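Note on the two hunks above: when grid_pinpoints is a string such as "(1x1),...,(6x6)", the first and last "(AxB)" matches are taken as the corners of a grid range, every (i, j) in that range is enumerated, and each pair is multiplied by the patch size to get pixel resolutions. A small standalone sketch of that expansion; the spec string and the 336 patch size are illustrative assumptions:

import re

def expand_grid_pinpoints(spec: str, patch_size: int):
    # "(1x1),...,(6x6)" -> every [i * patch_size, j * patch_size] for 1 <= i, j <= 6
    matches = re.findall(r"\((\d+)x(\d+)\)", spec)
    start = tuple(map(int, matches[0]))
    end = tuple(map(int, matches[-1]))
    return [
        [i * patch_size, j * patch_size]
        for i in range(start[0], end[0] + 1)
        for j in range(start[1], end[1] + 1)
    ]

print(expand_grid_pinpoints("(1x1),...,(6x6)", 336)[:3])
# [[336, 336], [336, 672], [336, 1008]]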
@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
  best_resolution = select_best_resolution(image.size, possible_resolutions)
  image_padded = resize_and_pad_image(image, best_resolution)

- patches = divide_to_patches(image_padded, processor.crop_size["height"])
-
- image_original_resize = image.resize(
- (processor.size["shortest_edge"], processor.size["shortest_edge"])
+ # For Siglip processor, only have size but no crop size
+ crop_size = (
+ processor.crop_size["height"]
+ if "crop_size" in processor.__dict__
+ else processor.size["height"]
  )
+ shortest_edge = (
+ processor.size["shortest_edge"]
+ if "shortest_edge" in processor.size
+ else processor.size["height"]
+ )
+ patches = divide_to_patches(image_padded, crop_size)
+
+ image_original_resize = image.resize((shortest_edge, shortest_edge))

  image_patches = [image_original_resize] + patches
  image_patches = [
- processor.preprocess(image_patch)["pixel_values"][0]
+ processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
  for image_patch in image_patches
  ]
  return np.stack(image_patches, axis=0)
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
  )
  image = image_processor.preprocess(image)["pixel_values"][0]
  new_images.append(image)
- elif image_aspect_ratio == "anyres":
+ elif "anyres" in image_aspect_ratio:
  for image in images:
  image = process_anyres_image(
  image, image_processor, model_cfg.image_grid_pinpoints