sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py
@@ -1,3 +1,6 @@
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
 
@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu
 
 _is_cpu = is_cpu()
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
 
 
 class Llama4ForConditionalGeneration(nn.Module):
@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
-        self.vision_model = Llama4VisionModel(config.vision_config)
-        self.multi_modal_projector = Llama4MultiModalProjector(config)
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
 
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
 
-        self.logits_processor = LogitsProcessor(config.text_config)
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
+
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint
+        # The config might say it's multimodal, but the checkpoint might be text-only
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if index file exists in cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback, assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
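Note: conceptually, the new `_has_vision_weights` / `_check_vision_weights_in_index` pair decides between text-only and multimodal mode by scanning the checkpoint's safetensors index instead of trusting the config. A minimal standalone sketch of that check (the directory path and pattern list below are illustrative, not part of the diff):

import json
import os

def has_vision_weights(checkpoint_dir: str) -> bool:
    # model.safetensors.index.json maps every tensor name to the shard file holding it.
    index_path = os.path.join(checkpoint_dir, "model.safetensors.index.json")
    if not os.path.isfile(index_path):
        return False  # no index found; caller falls back to assuming text-only
    with open(index_path) as f:
        weight_map = json.load(f).get("weight_map", {})
    # A text-only export (e.g. an FP8 repackaging without the vision tower) has no such prefixes.
    patterns = ("vision_model", "vision_tower", "multi_modal_projector")
    return any(p in tensor_name for tensor_name in weight_map for p in patterns)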
@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # For text-only models, return None or raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
             torch.concat([item.pixel_values for item in items])
             .to(next(self.vision_model.parameters()).device)
@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            image_data_embedding_func=image_embedding_func,
             positions=positions,
         )
 
@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         return name, loaded_weight
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
-
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )
 
-        num_experts = self.config.text_config.num_local_experts
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         for name, loaded_weight in weights:
-            if not "vision" in name:
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )
 
-                for param_name, weight_name, shard_id in stacked_params_mapping:
-                    if weight_name not in name:
-                        continue
-
-                    if "vision" in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                    break
+            if self._handle_scale_remapping(name, params_dict):
+                continue
+
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
+
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(
+                ".experts.down_proj", f".experts.w2{suffix}"
+            )
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
+
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
+
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
             else:
-                if ".experts" in name:
-                    # NOTE: llama4 fp8 has different weight format for experts
-                    if (
-                        "experts.gate_up_proj" not in name
-                        and "experts.down_proj" not in name
-                    ):
-                        for mapping in expert_params_mapping:
-                            param_name, weight_name, expert_id, shard_id = mapping
-                            if weight_name not in name:
-                                continue
-                            name = name.replace(weight_name, param_name)
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            weight_loader(
-                                param,
-                                loaded_weight,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
-                            break
-                    else:
-                        if ".gate_up_proj" in name:
-                            name_list = [
-                                name.replace(
-                                    ".experts.gate_up_proj", ".experts.w13_weight"
-                                )
-                            ] * 2
-                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                            shard_id_list = ["w1", "w3"]
-                        else:
-                            name_list = [
-                                name.replace(".experts.down_proj", ".experts.w2_weight")
-                            ]
-                            shard_id_list = ["w2"]
-                            loaded_weight_list = [loaded_weight]
-                        for name, loaded_weight, shard_id in zip(
-                            name_list, loaded_weight_list, shard_id_list
-                        ):
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            for expert_id in range(num_experts):
-                                weight_loader(
-                                    param,
-                                    loaded_weight[expert_id].T,
-                                    name,
-                                    shard_id=shard_id,
-                                    expert_id=expert_id,
-                                )
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
-                    weight_loader(param, loaded_weight)
+
+        return True
+
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)
 
     def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
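Note: for intuition on the `gate_up_proj` handling in `_handle_expert_weight_params` above, the checkpoint stores the gate and up projections fused along the last dimension, so `chunk(2, dim=-1)` recovers the `w1` and `w3` shards the fused MoE layer expects, and each per-expert slice is transposed before reaching the weight loader. A small sketch with made-up shapes (the sizes are illustrative only, not taken from any Llama 4 config):

import torch

num_experts, hidden, intermediate = 4, 8, 16

# Checkpoint layout: one fused [gate | up] projection per expert.
gate_up_proj = torch.randn(num_experts, hidden, 2 * intermediate)

# chunk(2, dim=-1) splits the fused tensor into the gate (w1) and up (w3) halves.
w1, w3 = gate_up_proj.chunk(2, dim=-1)
assert w1.shape == (num_experts, hidden, intermediate)
assert w3.shape == (num_experts, hidden, intermediate)

# Each expert's slice is transposed before it is handed to the weight loader,
# mirroring the weight_chunk[expert_id].T call in the diff above.
per_expert_shards = [w1[e].T for e in range(num_experts)]
assert per_expert_shards[0].shape == (intermediate, hidden)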
sglang/srt/multimodal/mm_utils.py
@@ -28,12 +28,12 @@ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
 
 """
 import ast
-import base64
 import math
 import re
 from io import BytesIO
 
 import numpy as np
+import pybase64
 from PIL import Image
 
 from sglang.srt.utils import flatten_nested_list
@@ -252,7 +252,7 @@ def process_anyres_image(image, processor, grid_pinpoints):
 
 
 def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
+    return Image.open(BytesIO(pybase64.b64decode(image, validate=True)))
 
 
 def expand2square(pil_img, background_color):
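Note: the mm_utils.py change above swaps the standard-library decoder for pybase64, a C-accelerated drop-in replacement, and `validate=True` rejects non-alphabet characters instead of silently discarding them as plain `base64.b64decode` does. A small usage sketch (assuming the pybase64 package is installed):

import pybase64

encoded = pybase64.b64encode(b"hello world")
assert pybase64.b64decode(encoded, validate=True) == b"hello world"

# With validate=True, malformed input raises instead of being silently ignored.
try:
    pybase64.b64decode(b"not base64!!", validate=True)
except Exception as exc:
    print(type(exc).__name__)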
sglang/srt/multimodal/processors/mllama4.py
@@ -60,70 +60,72 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )
 
         # Handle image resolutions and aspect ratios
-        if "pixel_values" in processor_output:
-            image_processor = processor.image_processor
-            tokenizer = self._processor.tokenizer
+        if "pixel_values" not in processor_output:  # no image processed
+            return None
 
-            # Calculate tile size and find supported resolutions
-            tile_size = self.vision_config.image_size
-            max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+        image_processor = processor.image_processor
+        tokenizer = self._processor.tokenizer
 
-            possible_resolutions = find_supported_resolutions(
-                max_num_chunks=max_num_tiles,
-                patch_size=SizeDict(height=tile_size, width=tile_size),
+        # Calculate tile size and find supported resolutions
+        tile_size = self.vision_config.image_size
+        max_num_tiles = getattr(self.vision_config, "max_patches", 1)
+
+        possible_resolutions = find_supported_resolutions(
+            max_num_chunks=max_num_tiles,
+            patch_size=SizeDict(height=tile_size, width=tile_size),
+        )
+
+        # Find best fit for each image
+        best_fit_sizes = [
+            get_best_fit(
+                (image.size[1], image.size[0]),  # (height, width)
+                torch.tensor(possible_resolutions),
+                resize_to_max_canvas=image_processor.resize_to_max_canvas,
             )
+            for image in processed_data.images
+        ]
+
+        # Calculate aspect ratios and patches per image
+        aspect_ratios = [
+            (image_size[0] // tile_size, image_size[1] // tile_size)
+            for image_size in best_fit_sizes
+        ]
+
+        patches_per_image = [
+            1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
+        ]
+
+        # Add to image_inputs
+        processor_output["aspect_ratios"] = aspect_ratios
+        processor_output["patches_per_image"] = torch.tensor(patches_per_image)
+
+        # Process embed_is_patch
+        vocab = tokenizer.get_vocab()
+        patch_id = vocab.get(processor.img_patch_token, -1)
+        image_end_id = vocab.get(processor.end_of_img_token, -1)
+
+        if patch_id != -1 and image_end_id != -1:
+            input_ids = processor_output["input_ids"].view(-1)
+
+            # Remove BOS token if present
+            if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
+                input_ids = input_ids[1:]
+
+            # Find image end indices and split input_ids
+            image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
+
+            if image_end_indices.size(0) > 0:
+                # Split at image boundaries
+                split_indices = (image_end_indices + 1)[:-1]
+                split_input_ids = torch.tensor_split(input_ids, split_indices)
+                split_input_ids = [x for x in split_input_ids if x.numel() > 0]
+
+                # Create embed_is_patch for each image
+                embed_is_patch = []
+                for per_image_input_ids in split_input_ids:
+                    embed_is_patch.append(per_image_input_ids == patch_id)
 
-            # Find best fit for each image
-            best_fit_sizes = [
-                get_best_fit(
-                    (image.size[1], image.size[0]),  # (height, width)
-                    torch.tensor(possible_resolutions),
-                    resize_to_max_canvas=image_processor.resize_to_max_canvas,
-                )
-                for image in processed_data.images
-            ]
-
-            # Calculate aspect ratios and patches per image
-            aspect_ratios = [
-                (image_size[0] // tile_size, image_size[1] // tile_size)
-                for image_size in best_fit_sizes
-            ]
-
-            patches_per_image = [
-                1 if r_h * r_w == 1 else 1 + r_h * r_w for (r_h, r_w) in aspect_ratios
-            ]
-
-            # Add to image_inputs
-            processor_output["aspect_ratios"] = aspect_ratios
-            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
-
-            # Process embed_is_patch
-            vocab = tokenizer.get_vocab()
-            patch_id = vocab.get(processor.img_patch_token, -1)
-            image_end_id = vocab.get(processor.end_of_img_token, -1)
-
-            if patch_id != -1 and image_end_id != -1:
-                input_ids = processor_output["input_ids"].view(-1)
-
-                # Remove BOS token if present
-                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
-                    input_ids = input_ids[1:]
-
-                # Find image end indices and split input_ids
-                image_end_indices = (input_ids == image_end_id).nonzero().view(-1)
-
-                if image_end_indices.size(0) > 0:
-                    # Split at image boundaries
-                    split_indices = (image_end_indices + 1)[:-1]
-                    split_input_ids = torch.tensor_split(input_ids, split_indices)
-                    split_input_ids = [x for x in split_input_ids if x.numel() > 0]
-
-                    # Create embed_is_patch for each image
-                    embed_is_patch = []
-                    for per_image_input_ids in split_input_ids:
-                        embed_is_patch.append(per_image_input_ids == patch_id)
-
-                    processor_output["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch
 
         # Convert to the format expected by SGLang
         processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
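Note: as a worked example of the tile arithmetic in the hunk above, assuming a 336-pixel tile purely for illustration: a best-fit canvas of 672x1008 gives an aspect ratio of (2, 3) and contributes 1 + 2*3 = 7 patches, while a 336x336 best fit keeps a single patch. A minimal sketch of that calculation (the tile size and shapes are assumptions, not values from the diff):

def patches_for_image(best_fit_hw, tile_size=336):
    # Mirror of the diff: the ratio counts whole tiles per side; a single-tile image
    # needs no extra view, otherwise one additional patch is added on top of the tiles.
    r_h, r_w = best_fit_hw[0] // tile_size, best_fit_hw[1] // tile_size
    return 1 if r_h * r_w == 1 else 1 + r_h * r_w

assert patches_for_image((336, 336)) == 1
assert patches_for_image((672, 1008)) == 7  # (2, 3) tiles plus one extra patch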