sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
@@ -39,6 +39,8 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UpdateWeightReqInput,
+    UpdateWeightReqOutput,
 )
 from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
@@ -54,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    configure_logger,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -85,10 +88,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -96,6 +95,7 @@ class ModelTpServer:
             context_length=server_args.context_length,
             model_overide_args=model_overide_args,
         )
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -135,11 +135,17 @@ class ModelTpServer:
             self.model_config.context_len - 1,
             self.max_total_num_tokens - 1,
         )
+
+        # Sync random seed
+        server_args.random_seed = broadcast_recv_input(
+            [server_args.random_seed],
+            self.tp_rank,
+            self.model_runner.tp_group.cpu_group,
+        )[0]
         set_random_seed(server_args.random_seed)
 
         # Print info
         logger.info(
-            f"[gpu={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
@@ -175,6 +181,13 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -184,6 +197,16 @@ class ModelTpServer:
                     "trust_remote_code": server_args.trust_remote_code,
                 },
                 skip_tokenizer_init=server_args.skip_tokenizer_init,
+                json_schema_mode=False,
+            )
+            self.json_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+                json_schema_mode=True,
             )
             self.jump_forward_cache = JumpForwardCache()
 
@@ -211,6 +234,9 @@ class ModelTpServer:
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
                     self.abort_request(recv_req)
+                elif isinstance(recv_req, UpdateWeightReqInput):
+                    success, message = self.update_weights(recv_req)
+                    self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
                 else:
                     raise ValueError(f"Invalid request: {recv_req}")
 
@@ -268,7 +294,7 @@ class ModelTpServer:
             self.num_generated_tokens = 0
             self.last_stats_tic = time.time()
             logger.info(
-                f"[gpu={self.gpu_id}] Decode batch. "
+                f"Decode batch. "
                 f"#running-req: {len(self.running_batch.reqs)}, "
                 f"#token: {num_used}, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
@@ -307,11 +333,16 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
+                image_hash = (
+                    hash(tuple(recv_req.image_hash))
+                    if isinstance(recv_req.image_hash, list)
+                    else recv_req.image_hash
+                )
                 req.pad_value = [
-                    (recv_req.image_hash) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                    (image_hash) % self.model_config.vocab_size,
+                    (image_hash >> 16) % self.model_config.vocab_size,
+                    (image_hash >> 32) % self.model_config.vocab_size,
+                    (image_hash >> 64) % self.model_config.vocab_size,
                 ]
                 req.image_size = recv_req.image_size
                 (
@@ -328,8 +359,17 @@ class ModelTpServer:
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
 
+            # Init regex fsm fron json
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                    req.sampling_params.json_schema
+                )
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        computed_regex_string
+                    )
             # Init regex fsm
-            if req.sampling_params.regex is not None:
+            elif req.sampling_params.regex is not None:
                 req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
                 if not self.disable_regex_jump_forward:
                     req.jump_forward_map = self.jump_forward_cache.query(
@@ -366,11 +406,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )
 
         if self.running_batch is not None:
@@ -416,15 +459,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-        logger.info(
-            f"[gpu={self.gpu_id}] Prefill batch. "
-            f"#new-seq: {len(can_run_list)}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
-        )
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
 
         # Return the new batch
         new_batch = ScheduleBatch.init_new(
@@ -440,11 +495,21 @@ class ModelTpServer:
         # Build batch tensors
         batch.prepare_for_extend(self.model_config.vocab_size)
 
+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
                 output = self.model_runner.forward(batch, ForwardMode.EXTEND)
                 next_token_ids = batch.sample(output.next_token_logits)
+                batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                    next_token_ids
+                )
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -477,9 +542,15 @@ class ModelTpServer:
                     req.output_ids.append(next_token_ids[i])
                     req.check_finished()
 
+                if req.regex_fsm is not None:
+                    req.regex_fsm_state = req.regex_fsm.get_next_state(
+                        req.regex_fsm_state, next_token_ids[i]
+                    )
+
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                else:
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)
 
                 if req is self.current_inflight_req:
@@ -579,7 +650,7 @@ class ModelTpServer:
             self.new_token_ratio = new_token_ratio
 
             logger.info(
-                "decode out of memory happened, "
+                "Decode out of memory happened. "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
             )
@@ -604,6 +675,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids = batch.sample(output.next_token_logits)
+        batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+            next_token_ids
+        )
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
@@ -620,6 +694,11 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.regex_fsm is not None:
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
+                    req.regex_fsm_state, next_token_id
+                )
+
             if req.finished():
                 self.tree_cache.cache_finished_req(req)
 
@@ -743,12 +822,15 @@ class ModelTpServer:
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
+            if_success = True
         else:
-            warnings.warn(
+            logging.warning(
                 f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.waiting_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
+            if_success = False
+        return if_success
 
     def abort_request(self, recv_req):
         # Delete requests in the waiting queue
@@ -768,6 +850,15 @@ class ModelTpServer:
                     req.finished_reason = FINISH_ABORT()
                     break
 
+    def update_weights(self, recv_req):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        if success:
+            flash_cache_success = self.flush_cache()
+            assert flash_cache_success, "Cache flush failed after updating weights"
+        return success, message
+
 
 def run_tp_server(
     gpu_id: int,
@@ -776,7 +867,9 @@ def run_tp_server(
     nccl_port: int,
     model_overide_args: dict,
 ):
-    """Run a tensor parallel server."""
+    """Run a tensor parallel model server."""
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
     try:
         model_server = ModelTpServer(
             gpu_id,
@@ -832,6 +925,7 @@ def broadcast_recv_input(
 
         dist.broadcast(tensor_size, src=0, group=dist_group)
         dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
         dist.broadcast(tensor_size, src=0, group=dist_group)
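For reference, the new "Sync random seed" block and the `return data` fix above both rely on the object-broadcast pattern implemented by broadcast_recv_input. The sketch below is a minimal standalone version of that pattern, not sglang's exact code: the helper name broadcast_pyobj is illustrative, and it assumes an already-initialized torch.distributed CPU (gloo) process group. It shows why rank 0 must also return the data, so that every rank ends up with the same value.

import pickle

import torch
import torch.distributed as dist


def broadcast_pyobj(data, rank: int, dist_group):
    """Broadcast a picklable Python object from rank 0 to every rank in `dist_group`."""
    if rank == 0:
        serialized = pickle.dumps(data)
        tensor_data = torch.ByteTensor(list(serialized))
        tensor_size = torch.tensor([tensor_data.numel()], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        dist.broadcast(tensor_data, src=0, group=dist_group)
        return data  # rank 0 returns as well, mirroring the `return data` fix above
    else:
        tensor_size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(tensor_size, src=0, group=dist_group)
        tensor_data = torch.empty(tensor_size.item(), dtype=torch.uint8)
        dist.broadcast(tensor_data, src=0, group=dist_group)
        return pickle.loads(bytes(tensor_data.tolist()))

Used like `seed = broadcast_pyobj([server_args.random_seed], tp_rank, cpu_group)[0]`, every tensor-parallel worker sees the same seed, which is what the hunk above does with broadcast_recv_input before calling set_random_seed.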
@@ -16,7 +16,8 @@ limitations under the License.
 """Memory pool."""
 
 import logging
-from typing import List, Union
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Union
 
 import torch
 
@@ -52,14 +53,21 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
-class BaseTokenToKVPool:
+class BaseTokenToKVPool(ABC):
     """A memory pool that maps a token to its kv cache locations"""
 
     def __init__(
         self,
         size: int,
+        dtype: torch.dtype,
     ):
         self.size = size
+        self.dtype = dtype
+        if dtype == torch.float8_e5m2:
+            # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
 
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
@@ -112,6 +120,28 @@ class BaseTokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
 
+    @abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
 
 
 class MHATokenToKVPool(BaseTokenToKVPool):
@@ -123,26 +153,52 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
         self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty(
+                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+            )
             for _ in range(layer_num)
        ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id].view(self.dtype)
         return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id].view(self.dtype)
         return self.v_buffer[layer_id]
 
     def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if cache_v.dtype != self.dtype:
+            cache_v = cache_v.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
+        else:
+            self.k_buffer[layer_id][loc] = cache_k
+            self.v_buffer[layer_id][loc] = cache_v
 
 
 class MLATokenToKVPool(BaseTokenToKVPool):
@@ -155,23 +211,41 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         qk_rope_head_dim: int,
         layer_num: int,
     ):
-        super().__init__(size)
+        super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
-                dtype=dtype,
+                dtype=self.store_dtype,
                 device="cuda",
             )
             for _ in range(layer_num)
         ]
 
     def get_key_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.store_dtype != self.dtype:
+            return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
 
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+        if self.store_dtype != self.dtype:
+            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
+        else:
+            self.kv_buffer[layer_id][loc] = cache_k
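For reference, the store_dtype changes above keep a torch.float8_e5m2 KV cache in a torch.uint8 buffer (indexed writes are not implemented for float8) and reinterpret the bits with view() on read. The following is a minimal CPU-side sketch of that round trip, assuming a PyTorch build with float8_e5m2 support (2.1+); the function name demo_fp8_store is illustrative, not an sglang API.

import torch


def demo_fp8_store(loc: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Storage buffer kept as uint8 (the "store_dtype"), so buffer[loc] = ... works.
    buffer = torch.zeros(16, dtype=torch.uint8)
    cache = values.to(torch.float8_e5m2)       # quantize the incoming values
    buffer[loc] = cache.view(torch.uint8)      # bit-preserving write through the uint8 view
    # Read back: index as uint8, reinterpret the bits as float8, then upcast for inspection.
    return buffer[loc].view(torch.float8_e5m2).to(torch.float32)


print(demo_fp8_store(torch.tensor([1, 3]), torch.tensor([0.5, -2.0])))

Because view() reinterprets the same bytes rather than converting them, the write/read pair is lossless; the only precision loss happens in the explicit cast to float8_e5m2, which is the same behavior the pool's set_kv_buffer/get_key_buffer pair exhibits.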
sglang/srt/mm_utils.py CHANGED
@@ -13,10 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
+# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
+"""
+Utilities for multi-modal models.
+
+This python file mainly contains utilities that were used in the
+image processing logic of llava-next including operations such as
+anyres and anyres_max
+
+Currently supports the anyres and anyres_max operation for CLIP and
+SigLip. For more information, you may refer to the paper or the blog
+
+LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/
+LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
+
+"""
 import ast
 import base64
 import math
+import re
 from io import BytesIO
 
 import numpy as np
@@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions):
     min_wasted_resolution = float("inf")
 
     for width, height in possible_resolutions:
+        # Calculate the downscaled size to keep the aspect ratio
         scale = min(width / original_width, height / original_height)
         downscaled_width, downscaled_height = int(original_width * scale), int(
             original_height * scale
         )
+
+        # Calculate effective and wasted resolutions
         effective_resolution = min(
             downscaled_width * downscaled_height, original_width * original_height
         )
@@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
     Returns:
         tuple: The shape of the image patch grid in the format (width, height).
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints):
     Returns:
         np.array: An np array containing the processed image patches.
     """
+    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
+        try:
+            patch_size = processor.size[0]
+        except Exception as e:
+            patch_size = processor.size["shortest_edge"]
+        assert patch_size in [
+            224,
+            336,
+            384,
+            448,
+            512,
+        ], "patch_size should be in [224, 336, 384, 448, 512]"
+        # Use regex to extract the range from the input string
+        matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
+        range_start = tuple(map(int, matches[0]))
+        range_end = tuple(map(int, matches[-1]))
+        # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
+        grid_pinpoints = [
+            (i, j)
+            for i in range(range_start[0], range_end[0] + 1)
+            for j in range(range_start[1], range_end[1] + 1)
+        ]
+        # Multiply all elements by patch_size
+        grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
+
     if type(grid_pinpoints) is list:
         possible_resolutions = grid_pinpoints
     else:
@@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints):
     best_resolution = select_best_resolution(image.size, possible_resolutions)
     image_padded = resize_and_pad_image(image, best_resolution)
 
-    patches = divide_to_patches(image_padded, processor.crop_size["height"])
-
-    image_original_resize = image.resize(
-        (processor.size["shortest_edge"], processor.size["shortest_edge"])
+    # For Siglip processor, only have size but no crop size
+    crop_size = (
+        processor.crop_size["height"]
+        if "crop_size" in processor.__dict__
+        else processor.size["height"]
     )
+    shortest_edge = (
+        processor.size["shortest_edge"]
+        if "shortest_edge" in processor.size
+        else processor.size["height"]
+    )
+    patches = divide_to_patches(image_padded, crop_size)
+
+    image_original_resize = image.resize((shortest_edge, shortest_edge))
 
     image_patches = [image_original_resize] + patches
     image_patches = [
-        processor.preprocess(image_patch)["pixel_values"][0]
+        processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0]
        for image_patch in image_patches
     ]
     return np.stack(image_patches, axis=0)
@@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg):
             )
             image = image_processor.preprocess(image)["pixel_values"][0]
             new_images.append(image)
-    elif image_aspect_ratio == "anyres":
+    elif "anyres" in image_aspect_ratio:
         for image in images:
             image = process_anyres_image(
                 image, image_processor, model_cfg.image_grid_pinpoints
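For reference, the new string form of grid_pinpoints handled above (e.g. "(1x1),(2x2),(3x3)") is expanded into every (width, height) grid between the first and last match and then scaled by the patch size into pixel resolutions. The snippet below is a small standalone reproduction of that expansion; the function name expand_grid_pinpoints is illustrative and not part of sglang.

import re


def expand_grid_pinpoints(grid_pinpoints: str, patch_size: int) -> list:
    # Extract the (WxH) pairs and keep only the first and last as the grid range.
    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    range_start = tuple(map(int, matches[0]))
    range_end = tuple(map(int, matches[-1]))
    # Enumerate every grid shape in the range, then scale to pixel resolutions.
    grids = [
        (i, j)
        for i in range(range_start[0], range_end[0] + 1)
        for j in range(range_start[1], range_end[1] + 1)
    ]
    return [[dim * patch_size for dim in pair] for pair in grids]


# "(1x1),(2x2)" with a 336-pixel patch yields [[336, 336], [336, 672], [672, 336], [672, 672]].
print(expand_grid_pinpoints("(1x1),(2x2)", 336))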