sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py CHANGED
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -54,7 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    get_int_token_logit_bias,
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -86,10 +88,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -97,6 +95,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -132,18 +131,21 @@ class ModelTpServer:
             ),
             self.model_runner.req_to_token_pool.size - 1,
         )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
         self.max_req_input_len = min(
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)
 
         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -179,6 +181,13 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -215,6 +224,9 @@ class ModelTpServer:
                 self.flush_cache()
             elif isinstance(recv_req, AbortReq):
                 self.abort_request(recv_req)
+            elif isinstance(recv_req, UpdateWeightReqInput):
+                success, message = self.update_weights(recv_req)
+                self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -272,7 +284,7 @@ class ModelTpServer:
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
             logger.info(
-                f"[gpu={self.gpu_id}] Decode batch. "
+                f"Decode batch. "
                 f"#running-req: {len(self.running_batch.reqs)}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -311,11 +323,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (recv_req.image_hash) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
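The hunk above changes how the multimodal pad values are derived: a list-valued image hash is first collapsed into a single integer with `hash(tuple(...))`, and four pad token ids are then sliced out of that integer with bit shifts modulo the vocabulary size. A minimal standalone sketch of that derivation (the vocabulary size of 32000 is an arbitrary example):

```python
# Illustration of the pad-value derivation above; vocab_size=32000 is an arbitrary example.
def pad_values_from_image_hash(image_hash, vocab_size=32000):
    # A list-valued hash (e.g. one entry per image) is collapsed into one integer first.
    if isinstance(image_hash, list):
        image_hash = hash(tuple(image_hash))
    return [
        (image_hash) % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]


print(pad_values_from_image_hash([123456789, 987654321]))
```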
@@ -370,11 +387,14 @@
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )
 
         if self.running_batch is not None:
@@ -420,15 +440,27 @@
             )
         else:
             tree_cache_hit_rate = 0.0
-        logger.info(
-            f"[gpu={self.gpu_id}] Prefill batch. "
-            f"#new-seq: {len(can_run_list)}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
-        )
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
 
         # Return the new batch
         new_batch = ScheduleBatch.init_new(
@@ -442,25 +474,41 @@
 
     def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
-        batch.prepare_for_extend(
-            self.model_config.vocab_size, self.int_token_logit_bias
-        )
+        batch.prepare_for_extend(self.model_config.vocab_size)
+
+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
 
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-                next_token_ids = batch.sample(output.next_token_logits)
+                sample_output, logits_output = self.model_runner.forward(
+                    batch, ForwardMode.EXTEND
+                )
+                next_token_ids = batch.check_sample_results(sample_output)
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
                 # Move logprobs to cpu
-                if output.next_token_logprobs is not None:
-                    output.next_token_logprobs = output.next_token_logprobs[
-                        torch.arange(len(next_token_ids), device=next_token_ids.device),
-                        next_token_ids,
-                    ].tolist()
-                    output.input_token_logprobs = output.input_token_logprobs.tolist()
-                    output.normalized_prompt_logprobs = (
-                        output.normalized_prompt_logprobs.tolist()
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs[
+                            torch.arange(
+                                len(next_token_ids), device=next_token_ids.device
+                            ),
+                            next_token_ids,
+                        ].tolist()
+                    )
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
                     )
 
                 next_token_ids = next_token_ids.tolist()
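For reference, the `torch.arange` indexing rewritten above gathers, for each request in the batch, the log-probability of the token that was actually sampled. A standalone illustration with random logits (the shapes used here are arbitrary examples):

```python
import torch

# Pick each row's logprob of its own sampled token, as in the hunk above.
next_token_logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)  # [batch, vocab]
next_token_ids = torch.tensor([11, 7, 42, 5])                           # sampled ids

picked = next_token_logprobs[
    torch.arange(len(next_token_ids), device=next_token_ids.device),
    next_token_ids,
].tolist()
print(picked)  # one float per request, ready to be returned with the response
```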
@@ -483,9 +531,15 @@
                     req.output_ids.append(next_token_ids[i])
                     req.check_finished()
 
+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.get_next_state(
+                        req.regex_fsm_state, next_token_ids[i]
+                    )
+
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                else:
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)
 
                 if req is self.current_inflight_req:
@@ -493,12 +547,14 @@
                     self.req_to_token_pool.free(req.req_pool_idx)
 
                 if req.return_logprob:
-                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    self.add_logprob_return_values(
+                        i, req, pt, next_token_ids, logits_output
+                    )
                     pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            embeddings = output.embeddings.tolist()
+            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = logits_output.embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -526,7 +582,7 @@
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output: LogitProcessorOutput,
+        output: LogitsProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -585,7 +641,7 @@
             self.new_token_ratio = new_token_ratio
 
             logger.info(
-                "decode out of memory happened, "
+                "Decode out of memory happened. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
@@ -608,12 +664,17 @@
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-        output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        sample_output, logits_output = self.model_runner.forward(
+            batch, ForwardMode.DECODE
+        )
+        next_token_ids = batch.check_sample_results(sample_output)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
 
         # Move logprobs to cpu
-        if output.next_token_logprobs is not None:
-            next_token_logprobs = output.next_token_logprobs[
+        if logits_output.next_token_logprobs is not None:
+            next_token_logprobs = logits_output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -626,6 +687,11 @@
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -634,7 +700,7 @@
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(output.output_top_logprobs[i])
+                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -749,12 +815,15 @@
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
-            warnings.warn(
+            logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
            )
+            if_success = False
+        return if_success
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
@@ -774,6 +843,15 @@
                     req.finished_reason = FINISH_ABORT()
                     break
 
+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+
 
 
 def run_tp_server(
@@ -782,7 +860,9 @@ def run_tp_server(
     nccl_port: int,
    model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,
@@ -838,6 +918,7 @@ def broadcast_recv_input(
 
             dist.broadcast(tensor_size, src=0, group=dist_group)
             dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
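The `return data` added above means rank 0 now gets the broadcast payload back from `broadcast_recv_input`, so callers such as the random-seed sync earlier in this file can use the same call on every rank. A sketch of the general pattern, assuming pickle-based serialization and an already initialized `torch.distributed` process group (`broadcast_pyobj` is a hypothetical stand-in, not the sglang function):

```python
import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank, dist_group):
    """Rank 0 broadcasts a picklable object; every rank, including rank 0, returns it."""
    if rank == 0:
        serialized = pickle.dumps(data)
        tensor_data = torch.tensor(list(serialized), dtype=torch.uint8)
        tensor_size = torch.tensor([len(serialized)], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        dist.broadcast(tensor_data, src=0, group=dist_group)
        return data
    else:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        tensor_data = torch.empty(tensor_size.item(), dtype=torch.uint8)
        dist.broadcast(tensor_data, src=0, group=dist_group)
        return pickle.loads(bytes(tensor_data.tolist()))
```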
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from typing import List, Union
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
 
 import torch
 
@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
 
+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
 
 
 class MHATokenToKVPool(BaseTokenToKVPool):
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=dtype,
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
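The float8 path added to both pools works by bit-reinterpretation rather than numeric conversion: values are written into a `torch.uint8` buffer (since `index_put` is not implemented for `torch.float8_e5m2`) and viewed back as float8 when the buffer is read. A small sketch of that trick, assuming a PyTorch build with float8 support:

```python
import torch

store = torch.zeros(8, dtype=torch.uint8)          # the pool's store_dtype buffer
cache_k = torch.randn(3).to(torch.float8_e5m2)     # incoming kv values in the pool dtype
loc = torch.tensor([0, 3, 5])

# Write path: reinterpret the float8 bits as uint8, then index_put into the buffer.
store[loc] = cache_k.view(torch.uint8)

# Read path: the same bits, viewed as float8 again (converted to float32 only for printing).
print(store.view(torch.float8_e5m2).to(torch.float32))
```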
sglang/srt/mm_utils.py CHANGED
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This python file mainly contains utilities that were used in the
+image processing logic of llava-next including operations such as
+anyres and anyres_max
+
+Currently supports the anyres and anyres_max operation for CLIP and
+SigLip. For more information, you may refer to the paper or the blog
+
+LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
+
+"""
 import ast
 import base64
 import math
+import re
 from io import BytesIO
 
 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
     min_wasted_resolution = float("inf")
 
     for width, height in possible_resolutions:
+        # Calculate the downscaled size to keep the aspect ratio
         scale = min(width / original_width, height / original_height)
         downscaled_width, downscaled_height = int(original_width * scale), int(
             original_height * scale
         )
+
+        # Calculate effective and wasted resolutions
         effective_resolution = min(
             downscaled_width * downscaled_height, original_width * original_height
         )
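For readers unfamiliar with this helper, the comments added above describe the selection rule: each candidate canvas is scored by how many useful pixels survive aspect-preserving downscaling and how many are wasted. A standalone re-implementation for illustration only (not the library function itself), with an example input:

```python
def pick_resolution(original_size, possible_resolutions):
    original_width, original_height = original_size
    best_fit, max_effective, min_wasted = None, 0, float("inf")
    for width, height in possible_resolutions:
        # Downscale while keeping the aspect ratio.
        scale = min(width / original_width, height / original_height)
        down_w, down_h = int(original_width * scale), int(original_height * scale)
        # Useful pixels vs. pixels of the canvas left unused.
        effective = min(down_w * down_h, original_width * original_height)
        wasted = width * height - effective
        if effective > max_effective or (effective == max_effective and wasted < min_wasted):
            best_fit, max_effective, min_wasted = (width, height), effective, wasted
    return best_fit


# A 1000x600 image against a few candidate canvases -> picks (672, 672)
print(pick_resolution((1000, 600), [(336, 672), (672, 336), (672, 672), (1008, 336)]))
```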
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Returns:
         tuple: The shape of the image patch grid in the format (width, height).
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
     Returns:
         np.array: An np array containing the processed image patches.
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        try:
+            patch_size = processor.size[0]
+        except Exception as e:
+            patch_size = processor.size["shortest_edge"]
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
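Both functions above now also accept grid_pinpoints as a range string such as "(1x1),...,(6x6)", which is expanded into every grid shape in that range and scaled by the patch size into pixel resolutions. A standalone sketch of that expansion with a small worked example:

```python
import re


def expand_grid_pinpoints(grid_pinpoints: str, patch_size: int):
    # Extract the first and last "(AxB)" groups, enumerate the full range, scale by patch size.
    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    range_start = tuple(map(int, matches[0]))
    range_end = tuple(map(int, matches[-1]))
    pairs = [
        (i, j)
        for i in range(range_start[0], range_end[0] + 1)
        for j in range(range_start[1], range_end[1] + 1)
    ]
    return [[dim * patch_size for dim in pair] for pair in pairs]


print(expand_grid_pinpoints("(1x1),...,(2x2)", 384))
# [[384, 384], [384, 768], [768, 384], [768, 768]]
```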
@@ -156,15 +219,24 @@
     best_resolution = select_best_resolution(image.size, possible_resolutions)
     image_padded = resize_and_pad_image(image, best_resolution)
 
-    patches = divide_to_patches(image_padded, processor.crop_size["height"])
-
-    image_original_resize = image.resize(
-        (processor.size["shortest_edge"], processor.size["shortest_edge"])
+    # For Siglip processor, only have size but no crop size
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
     )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
+    patches = divide_to_patches(image_padded, crop_size)
+
+    image_original_resize = image.resize((shortest_edge, shortest_edge))
 
     image_patches = [image_original_resize] + patches
     image_patches = [
-        processor.preprocess(image_patch)["pixel_values"][0]
+        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
         for image_patch in image_patches
     ]
     return np.stack(image_patches, axis=0)
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
             )
             image = image_processor.preprocess(image)["pixel_values"][0]
            new_images.append(image)
-    elif image_aspect_ratio == "anyres":
+    elif "anyres" in image_aspect_ratio:
        for image in images:
            image = process_anyres_image(
                image, image_processor, model_cfg.image_grid_pinpoints