sglang 0.2.13__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (74)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/conversation.py +50 -1
  11. sglang/srt/hf_transformers_utils.py +22 -23
  12. sglang/srt/layers/activation.py +24 -1
  13. sglang/srt/layers/decode_attention.py +338 -50
  14. sglang/srt/layers/fused_moe/layer.py +2 -2
  15. sglang/srt/layers/layernorm.py +3 -0
  16. sglang/srt/layers/logits_processor.py +60 -23
  17. sglang/srt/layers/radix_attention.py +3 -4
  18. sglang/srt/layers/sampler.py +154 -0
  19. sglang/srt/managers/controller_multi.py +2 -8
  20. sglang/srt/managers/controller_single.py +7 -10
  21. sglang/srt/managers/detokenizer_manager.py +20 -9
  22. sglang/srt/managers/io_struct.py +44 -11
  23. sglang/srt/managers/policy_scheduler.py +5 -2
  24. sglang/srt/managers/schedule_batch.py +52 -167
  25. sglang/srt/managers/tokenizer_manager.py +192 -83
  26. sglang/srt/managers/tp_worker.py +130 -43
  27. sglang/srt/mem_cache/memory_pool.py +82 -8
  28. sglang/srt/mm_utils.py +79 -7
  29. sglang/srt/model_executor/cuda_graph_runner.py +49 -11
  30. sglang/srt/model_executor/forward_batch_info.py +59 -27
  31. sglang/srt/model_executor/model_runner.py +210 -61
  32. sglang/srt/models/chatglm.py +4 -12
  33. sglang/srt/models/commandr.py +5 -1
  34. sglang/srt/models/dbrx.py +5 -1
  35. sglang/srt/models/deepseek.py +5 -1
  36. sglang/srt/models/deepseek_v2.py +5 -1
  37. sglang/srt/models/gemma.py +5 -1
  38. sglang/srt/models/gemma2.py +15 -7
  39. sglang/srt/models/gpt_bigcode.py +5 -1
  40. sglang/srt/models/grok.py +16 -2
  41. sglang/srt/models/internlm2.py +5 -1
  42. sglang/srt/models/llama2.py +7 -3
  43. sglang/srt/models/llama_classification.py +2 -2
  44. sglang/srt/models/llama_embedding.py +4 -0
  45. sglang/srt/models/llava.py +176 -59
  46. sglang/srt/models/minicpm.py +5 -1
  47. sglang/srt/models/mixtral.py +5 -1
  48. sglang/srt/models/mixtral_quant.py +5 -1
  49. sglang/srt/models/qwen.py +5 -2
  50. sglang/srt/models/qwen2.py +13 -3
  51. sglang/srt/models/qwen2_moe.py +5 -14
  52. sglang/srt/models/stablelm.py +5 -1
  53. sglang/srt/openai_api/adapter.py +117 -37
  54. sglang/srt/sampling/sampling_batch_info.py +209 -0
  55. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -0
  56. sglang/srt/server.py +84 -56
  57. sglang/srt/server_args.py +43 -15
  58. sglang/srt/utils.py +26 -16
  59. sglang/test/runners.py +23 -31
  60. sglang/test/simple_eval_common.py +9 -10
  61. sglang/test/simple_eval_gpqa.py +2 -1
  62. sglang/test/simple_eval_humaneval.py +2 -2
  63. sglang/test/simple_eval_math.py +2 -1
  64. sglang/test/simple_eval_mmlu.py +2 -1
  65. sglang/test/test_activation.py +55 -0
  66. sglang/test/test_utils.py +36 -53
  67. sglang/version.py +1 -1
  68. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/METADATA +92 -25
  69. sglang-0.2.14.dist-info/RECORD +114 -0
  70. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  71. sglang/launch_server_llavavid.py +0 -29
  72. sglang-0.2.13.dist-info/RECORD +0 -112
  73. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  74. {sglang-0.2.13.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
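The change that recurs across nearly every model file below is a sampling refactor: each *ForCausalLM class now builds a Sampler from sglang.srt.layers.sampler, and forward() returns a (sample_output, logits_output) pair instead of the logits-processor output alone (the container type is also renamed from LogitProcessorOutput to LogitsProcessorOutput). The sketch below is a minimal, self-contained illustration of that calling convention only; the Tiny* classes are stand-ins written for this note, not the real sglang layers, whose internals are not shown in this diff.

import torch
from torch import nn

class TinyLogitsProcessor(nn.Module):
    # stand-in for sglang.srt.layers.logits_processor.LogitsProcessor
    def forward(self, hidden_states, lm_head_weight):
        return hidden_states @ lm_head_weight.t()

class TinySampler(nn.Module):
    # stand-in for sglang.srt.layers.sampler.Sampler (greedy pick for illustration)
    def forward(self, logits_output):
        return torch.argmax(logits_output, dim=-1)

class TinyForCausalLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head_weight = nn.Parameter(torch.randn(vocab_size, hidden_size))
        self.logits_processor = TinyLogitsProcessor()
        self.sampler = TinySampler()  # new in 0.2.14: sampling moves inside the model

    @torch.no_grad()
    def forward(self, input_ids):
        hidden_states = self.embed_tokens(input_ids)
        logits_output = self.logits_processor(hidden_states, self.lm_head_weight)
        # 0.2.13 returned logits_output directly; 0.2.14 samples as well
        # and returns both objects.
        sample_output = self.sampler(logits_output)
        return sample_output, logits_output

sample_output, logits_output = TinyForCausalLM()(torch.tensor([[1, 2, 3]]))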
sglang/srt/models/gemma2.py CHANGED
@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import GeluAndMul
 
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
@@ -39,14 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
 # SGLang assumes exclusive
-def get_window_size(config):
+def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
@@ -135,7 +136,7 @@ class Gemma2MLP(nn.Module):
                 "function. Please set `hidden_act` and `hidden_activation` to "
                 "`gelu_pytorch_tanh`."
             )
-        self.act_fn = GeluAndMul(approximate="tanh")
+        self.act_fn = GeluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         gate_up, _ = self.gate_up_proj(x)
@@ -213,7 +214,11 @@ class Gemma2Attention(nn.Module):
             self.scaling,
            num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else None,
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )
 
@@ -392,6 +397,7 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -402,12 +408,14 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
-    def get_window_size(self):
-        return get_window_size(self.config)
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
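The renamed helper keeps encoding the same off-by-one noted in its comment: HF treats the configured sliding window as including the current token, while SGLang expects the number of previous tokens only, hence the "- 1". A worked value, assuming Gemma 2's usual config.sliding_window of 4096 (an assumption for illustration, not read from this diff):

# get_attention_sliding_window_size(config) with config.sliding_window == 4096
# returns 4095: 4095 previous tokens plus the current token = a 4096-token window.
sliding_window = 4096
attention_sliding_window_size = sliding_window - 1  # 4095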
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -35,6 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -261,6 +262,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -270,9 +272,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -297,9 +298,13 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+
+        self.use_presharded_weights = True
+
         warnings.filterwarnings("ignore", category=FutureWarning)
 
     def forward(
@@ -310,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -355,6 +362,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
 
+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -363,7 +377,7 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    **extra_kwargs,
                 )
                 break
             else:
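The presharded-weight handling also changes shape: instead of always passing pre_sharded=... to the expert weight loader, the model records self.use_presharded_weights and forwards the flag only when set, via **extra_kwargs. A small sketch of that pattern in isolation (names are illustrative; the real loader lives in vllm/sglang and is not reproduced here):

def call_weight_loader(weight_loader, param, loaded_weight, use_presharded_weights):
    # Only mention the keyword when the model opted in, so loaders that do not
    # accept use_presharded_weights keep working unchanged.
    extra_kwargs = {"use_presharded_weights": True} if use_presharded_weights else {}
    return weight_loader(param, loaded_weight, **extra_kwargs)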
sglang/srt/models/internlm2.py CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py CHANGED
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> LogitProcessorOutput:
+    ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def get_module_name(self, name):
         stacked_params_mapping = [
sglang/srt/models/llama_classification.py CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)
 
-        return LogitProcessorOutput(
+        return LogitsProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
sglang/srt/models/llama_embedding.py CHANGED
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.pooler(hidden_states, input_metadata)
 
sglang/srt/models/llava.py CHANGED
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Inference-only LLaVa model compatible with HuggingFace weights."""
 
+import math
+import re
 from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
@@ -26,6 +28,8 @@ from transformers import (
     LlavaConfig,
     MistralConfig,
     Qwen2Config,
+    SiglipVisionConfig,
+    SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
@@ -63,34 +67,61 @@ class LlavaLlamaForCausalLM(nn.Module):
         )
 
     def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
+
+        # hardcode for spatial_unpad + anyres
+        image_aspect_ratio = "anyres" if len(image_size) == 1 else "pad"
+        offset_list = []
+        for image_s in image_size:
+            if len(image_size) > 16:
+                # 2x2 pooling with stride 2
+                new_image_feature_len = (
+                    math.ceil(self.image_size / self.patch_size / 2) ** 2
+                )
+            else:
+                new_image_feature_len = self.image_feature_len  # multiimage
+
             height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
+            if "anyres" in image_aspect_ratio:
+                num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                    image_s,
+                    self.image_grid_pinpoints,
+                    self.vision_tower.config.image_size,
+                )
+                h = num_patch_height * height
+                w = num_patch_width * width
+                new_h, new_w = unpad_image_shape(h, w, image_s)
+
+                if "anyres_max" in self.config.image_aspect_ratio:
+                    matched_anyres_max_num_patches = re.match(
+                        r"anyres_max_(\d+)", self.config.image_aspect_ratio
+                    )
+                    if matched_anyres_max_num_patches:
+                        max_num_patches = int(matched_anyres_max_num_patches.group(1))
+                    # times = math.sqrt(h * w / (max_num_patches * unit**2))
+                    times = math.sqrt(
+                        new_h * new_w / (max_num_patches * self.image_feature_len)
                     )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
-        )
-        offset = input_ids.index(self.config.image_token_index)
-        # old_len + pad_len - 1, because we need to remove image_token_id
-        new_input_ids = (
-            input_ids[:offset]
-            + pad_ids[:new_image_feature_len]
-            + input_ids[offset + 1 :]
-        )
-        return new_input_ids, offset
+                    if times > 1.1:
+                        new_h = int(new_h // times)
+                        new_w = int(new_w // times)
+                new_image_feature_len += new_h * (new_w + 1)
+
+            pad_ids = pad_value * (
+                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            )
+            # print("calculated new_image_feature_len: ", new_image_feature_len)
+            try:
+                offset = input_ids.index(self.config.image_token_index)
+            except ValueError:
+                offset = 0
+            # old_len + pad_len - 1, because we need to remove image_token_id
+            input_ids = (
+                input_ids[:offset]
+                + pad_ids[:new_image_feature_len]
+                + input_ids[offset + 1 :]
+            )
+            offset_list.append(offset)
+        return input_ids, offset_list
 
     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
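For the new anyres_max branch, the token budget works roughly like this: the helper derives the unpadded patch grid for the image, and if sqrt(grid_area / (max_num_patches * base_feature_len)) exceeds 1.1 it shrinks the grid before adding new_h * (new_w + 1) padded positions (the +1 per row corresponds to the image_newline token used later in this file). A standalone recalculation under assumed numbers (336-px tower, 14-px patches, a 2x3 anyres grid, and skipping the unpad_image_shape step), purely to illustrate the arithmetic:

import math
import re

image_size, patch_size = 336, 14
num_patches_per_side = image_size // patch_size        # 24
image_feature_len = num_patches_per_side ** 2          # 576 base tokens
image_aspect_ratio = "anyres_max_9"

height = width = num_patches_per_side
num_patch_height, num_patch_width = 2, 3               # pretend grid from get_anyres_image_grid_shape
new_h, new_w = num_patch_height * height, num_patch_width * width

matched = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
if matched:
    max_num_patches = int(matched.group(1))
    times = math.sqrt(new_h * new_w / (max_num_patches * image_feature_len))
    if times > 1.1:  # only downscale when the grid overshoots the budget
        new_h, new_w = int(new_h // times), int(new_w // times)

new_image_feature_len = image_feature_len + new_h * (new_w + 1)
print(new_image_feature_len)  # 4080 token positions reserved for this image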
@@ -124,7 +155,6 @@ class LlavaLlamaForCausalLM(nn.Module):
 
         # Embed text input
         input_embeds = self.language_model.model.embed_tokens(input_ids)
-
         # Embed vision input
         need_vision = (
             (positions[input_metadata.extend_start_loc] < self.image_feature_len)
@@ -163,27 +193,73 @@
 
         if self.mm_patch_merge_type.startswith("spatial"):
             new_image_features = []
+            height = width = self.num_patches_per_side
             for image_idx, image_feature in enumerate(image_features):
-                if image_feature.shape[0] > 1:
+                if len(image_sizes[image_idx]) == 1:
+                    image_aspect_ratio = (
+                        self.config.image_aspect_ratio
+                    )  # single image
+                else:
+                    image_aspect_ratio = "pad"  # multi image
+                # image_aspect_ratio = (
+                #     "anyres" if len(image_sizes[image_idx]) == 1 else "pad"
+                # )
+                if (
+                    image_feature.shape[0] > 1
+                    and "anyres" in image_aspect_ratio
+                ):
                     base_image_feature = image_feature[0]
                     image_feature = image_feature[1:]
-                    height = width = self.num_patches_per_side
                     assert height * width == base_image_feature.shape[0]
-                    if self.image_aspect_ratio == "anyres":
-                        (
-                            num_patch_width,
-                            num_patch_height,
-                        ) = get_anyres_image_grid_shape(
-                            image_sizes[image_idx],
-                            self.image_grid_pinpoints,
-                            self.vision_tower.config.image_size,
+
+                    if "anyres_max" in image_aspect_ratio:
+                        matched_anyres_max_num_patches = re.match(
+                            r"anyres_max_(\d+)", image_aspect_ratio
                         )
+                        if matched_anyres_max_num_patches:
+                            max_num_patches = int(
+                                matched_anyres_max_num_patches.group(1)
+                            )
+
+                    if (
+                        image_aspect_ratio == "anyres"
+                        or "anyres_max" in image_aspect_ratio
+                    ):
+                        vision_tower_image_size = self.image_size
+                        try:
+                            num_patch_width, num_patch_height = (
+                                get_anyres_image_grid_shape(
+                                    image_sizes[image_idx][0],
+                                    self.config.image_grid_pinpoints,
+                                    vision_tower_image_size,
+                                )
+                            )
+                        except Exception as e:
+                            print(f"Error: {e}")
+                            num_patch_width, num_patch_height = 2, 2
                         image_feature = image_feature.view(
                             num_patch_height, num_patch_width, height, width, -1
                         )
                     else:
-                        raise NotImplementedError()
+                        image_feature = image_feature.view(
+                            2, 2, height, width, -1
+                        )
+
+                    # (
+                    #     num_patch_width,
+                    #     num_patch_height,
+                    # ) = get_anyres_image_grid_shape(
+                    #     image_sizes[image_idx][0],
+                    #     self.image_grid_pinpoints,
+                    #     self.vision_tower.config.image_size,
+                    # )
+
+                    # image_feature = image_feature.view(
+                    #     num_patch_height, num_patch_width, height, width, -1
+                    # )
+
                     if "unpad" in self.mm_patch_merge_type:
+                        unit = image_feature.shape[2]
                         image_feature = image_feature.permute(
                             4, 0, 2, 1, 3
                         ).contiguous()
@@ -191,8 +267,23 @@ class LlavaLlamaForCausalLM(nn.Module):
                             2, 3
                         )
                         image_feature = unpad_image(
-                            image_feature, image_sizes[image_idx]
+                            image_feature, image_sizes[image_idx][0]
                         )
+                        if (
+                            "anyres_max" in image_aspect_ratio
+                            and matched_anyres_max_num_patches
+                        ):
+                            c, h, w = image_feature.shape
+                            times = math.sqrt(
+                                h * w / (max_num_patches * unit**2)
+                            )
+                            if times > 1.1:
+                                image_feature = image_feature[None]
+                                image_feature = nn.functional.interpolate(
+                                    image_feature,
+                                    [int(h // times), int(w // times)],
+                                    mode="bilinear",
+                                )[0]
                         image_feature = torch.cat(
                             (
                                 image_feature,
@@ -213,16 +304,31 @@
                     image_feature = torch.cat(
                         (base_image_feature, image_feature), dim=0
                     )
+                    image_feature = image_feature.unsqueeze(0)
                 else:
-                    image_feature = image_feature[0]
-                    if "unpad" in self.mm_patch_merge_type:
-                        image_feature = torch.cat(
-                            (
-                                image_feature,
-                                self.language_model.model.image_newline[None],
-                            ),
-                            dim=0,
+                    if image_feature.shape[0] > 16:  # video
+                        # 2x2 pooling
+                        num_of_frames = image_feature.shape[0]
+                        image_feature = image_feature.view(
+                            num_of_frames, height, width, -1
                         )
+                        image_feature = image_feature.permute(
+                            0, 3, 1, 2
+                        ).contiguous()  # N, C, H, W
+                        height, weight = image_feature.shape[2:]
+                        scaled_shape = [
+                            math.ceil(height / 2),
+                            math.ceil(weight / 2),
+                        ]
+                        image_feature = nn.functional.interpolate(
+                            image_feature, size=scaled_shape, mode="bilinear"
+                        )
+                        image_feature = (
+                            image_feature.flatten(2)
+                            .transpose(1, 2)
+                            .contiguous()
+                        )  # N, C, H*W
+
                 new_image_features.append(image_feature)
             image_features = new_image_features
 
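The new video branch compresses each frame's features with a bilinear 2x2 downsample before they are written into the prompt embeddings. A standalone shape walk-through with made-up sizes (32 frames, a 24x24 patch grid, 64-dim features), independent of the model code above:

import math
import torch
from torch import nn

num_of_frames, height, width, dim = 32, 24, 24, 64
image_feature = torch.randn(num_of_frames, height * width, dim)

image_feature = image_feature.view(num_of_frames, height, width, -1)
image_feature = image_feature.permute(0, 3, 1, 2).contiguous()          # N, C, H, W
scaled_shape = [math.ceil(height / 2), math.ceil(width / 2)]             # 12 x 12
image_feature = nn.functional.interpolate(
    image_feature, size=scaled_shape, mode="bilinear"
)
image_feature = image_feature.flatten(2).transpose(1, 2).contiguous()    # N, H*W, C
print(image_feature.shape)  # torch.Size([32, 144, 64]) -- 4x fewer tokens per frame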
@@ -233,21 +339,22 @@
                     continue
 
                 start_idx = extend_start_loc_cpu[i]
-                pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                pad_dim = image_features[pt].shape[-1]  # 576, 4096
                 dim = input_embeds.shape[1]
                 assert (
                     pad_dim == dim
                 ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                 # Fill in the placeholder for the image
                 try:
-                    input_embeds[
-                        start_idx
-                        + image_offsets[i] : start_idx
-                        + image_offsets[i]
-                        + pad_len
-                    ] = image_features[pt]
+                    for j, image_off in enumerate(image_offsets[i]):
+                        # print("actual image_features length: ", image_features[pt][j].shape[0])
+                        pad_len = image_features[pt][j].shape[0]
+                        input_embeds[
+                            start_idx + image_off : start_idx + image_off + pad_len
+                        ] = image_features[pt][j]
                 except RuntimeError as e:
                     print(f"RuntimeError in llava image encoding: {e}")
+                    print(image_features[pt].shape)
                     print(input_embeds.shape)
                     print(start_idx, image_offsets[i])
                 pt += 1
@@ -262,9 +369,16 @@
         # load clip vision model by cfg['mm_vision_tower']:
         # huggingface_name or path_of_clip_relative_to_llava_model_dir
         vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
+        if "clip" in vision_path:
+            self.vision_tower = CLIPVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+        elif "siglip" in vision_path:
+            self.vision_tower = SiglipVisionModel.from_pretrained(
+                vision_path, torch_dtype=torch.float16
+            ).cuda()
+            # Siglip needs all feature tokens
+            self.config.mm_vision_select_feature = "full"
         self.vision_tower.eval()
 
         self.vision_feature_layer = self.config.mm_vision_select_layer
@@ -276,8 +390,11 @@
         self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
         self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
 
-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
+        self.image_feature_len = int((self.image_size // self.patch_size) ** 2)
+        if (
+            self.vision_feature_select_strategy == "patch"
+            or self.vision_feature_select_strategy == "full"
+        ):
             pass
         elif self.vision_feature_select_strategy == "cls_patch":
             self.image_feature_len += 1
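Two small, related fixes close out the llava changes: image_feature_len now uses integer division, and the select-strategy check also accepts "full", which the SigLIP branch above sets (SigLIP vision towers generally expose no CLS token, so all feature tokens are kept). A quick sanity check of the length computation, assuming a CLIP ViT-L/14 tower at 336 px (the backbone size is an assumption for illustration):

image_size, patch_size = 336, 14
image_feature_len = int((image_size // patch_size) ** 2)
print(image_feature_len)  # 576, matching the "576, 4096" comment earlier in this diff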
sglang/srt/models/minicpm.py CHANGED
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -297,6 +298,7 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -314,9 +316,11 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
+        logits_output = self.logits_processor(
            input_ids, hidden_states, lm_head_weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/mixtral.py CHANGED
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     def forward(
         self,
@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [