fount-vlm-nell-02 0.3.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,340 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..interpolate import resize_bilinear
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        head_dim = dims // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, x, mask=None):
+        queries = self.q_proj(x)
+        keys = self.k_proj(x)
+        values = self.v_proj(x)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.out_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="precise")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        r = self.self_attn(self.layer_norm1(x), mask)
+        h = x + r
+        r = self.mlp(self.layer_norm2(h))
+        return h + r
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        encoder_states = (x,) if output_hidden_states else None
+        h = x
+        for l in self.layers:
+            x = l(x, mask=mask)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+            h = x
+
+        return (h, encoder_states)
+
+
+def gaussian_blur_axis(image, sigma, axis):
+    """
+    Applies a 1D Gaussian blur along the given axis.
+    This version works for arrays with any number of dimensions.
+    """
+    radius = int(3 * sigma)
+    if radius < 1:
+        return image
+    x = mx.arange(-radius, radius + 1)
+    kernel = mx.exp(-(x**2) / (2 * sigma**2))
+    kernel = kernel / mx.sum(kernel)
+
+    # MLX doesn't have a direct apply_along_axis equivalent,
+    # so we'll implement the convolution differently based on the axis
+
+    # Helper function to apply 1D convolution along specific axis
+    def conv_1d(array, kernel, axis):
+        # Reshape kernel to broadcast along the right dimensions
+        kernel_shape = [1] * image.ndim
+        kernel_shape[axis] = len(kernel)
+        kernel_reshaped = kernel.reshape(kernel_shape)
+
+        # Pad the array
+        pad_width = [(0, 0)] * image.ndim
+        pad_width[axis] = (radius, radius)
+        padded = mx.pad(array, pad_width, mode="edge")
+
+        # Perform convolution via sliding window sum
+        result = mx.zeros_like(array)
+        slices = [slice(None)] * padded.ndim
+
+        for i in range(2 * radius + 1):
+            slices[axis] = slice(i, i + array.shape[axis])
+            result = result + padded[tuple(slices)] * kernel_reshaped
+
+        return result
+
+    return conv_1d(image, kernel, axis)
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            config.num_channels,
+            config.hidden_size,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    @staticmethod
+    def resize_positional_embeddings(
+        positional_embeddings: mx.array,
+        spatial_shapes: mx.array,
+        max_length: int,
+    ) -> mx.array:
+        """
+        Resize positional embeddings to image-specific size and pad to a fixed size.
+
+        Args:
+            positional_embeddings (`mx.array`):
+                Position embeddings of shape (height, width, embed_dim)
+            spatial_shapes (`mx.array`):
+                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+            max_length (`int`):
+                Maximum length of the positional embeddings to pad resized positional embeddings to
+
+        Returns:
+            `mx.array`: Embeddings of shape (batch_size, max_length, embed_dim)
+        """
+        batch_size = spatial_shapes.shape[0]
+        embed_dim = positional_embeddings.shape[-1]
+        source_dtype = positional_embeddings.dtype
+
+        resulted_positional_embeddings = mx.zeros(
+            (batch_size, max_length, embed_dim)
+        ).astype(source_dtype)
+
+        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+        positional_embeddings = positional_embeddings.transpose(2, 0, 1).reshape(
+            1, embed_dim, -1
+        )
+
+        # Upcast to float32 so the antialiased resize is numerically stable
+        # (MLX arrays carry no device attribute, so no CPU-only check is needed)
+        positional_embeddings = positional_embeddings.astype(mx.float32)
+
+        for i in range(batch_size):
+            # (1, dim, height, width) -> (1, dim, target_height, target_width)
+            height, width = spatial_shapes[i]
+            # Then upsample width dimension
+            resized_embeddings = resize_bilinear(
+                positional_embeddings,
+                (height, width),
+                align_corners=False,
+                antialias=True,
+            )
+
+            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+            resized_embeddings = resized_embeddings.reshape(
+                embed_dim, height * width
+            ).transpose(1, 0)  # MLX transpose takes a permutation; (1, 0) swaps to (H*W, dim)
+
+            # Cast to original dtype
+            resized_embeddings = resized_embeddings.astype(source_dtype)
+
+            resulted_positional_embeddings[i, : height * width] = resized_embeddings
+            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+        return resulted_positional_embeddings
+
+    def __call__(
+        self, x: mx.array, spatial_shapes: Optional[mx.array] = None
+    ) -> mx.array:
+        batch_size = x.shape[0]
+        patch_embeddings = self.patch_embedding(x)
+        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+        if spatial_shapes is None:
+            position_ids = mx.array(np.arange(self.num_positions)[None, :])
+            embeddings = patch_embeddings
+            embeddings += self.position_embedding(position_ids)
+
+        else:
+            # Get resized and padded positional embeddings
+            positional_embeddings = self.position_embedding.weight.reshape(
+                self.image_size // self.patch_size, self.image_size // self.patch_size, -1
+            )
+
+            resized_positional_embeddings = self.resize_positional_embeddings(
+                positional_embeddings, spatial_shapes, max_length=x.shape[1]
+            )
+
+            # Add positional embeddings to patch embeddings
+            embeddings = patch_embeddings + resized_positional_embeddings
+        return embeddings
+
+
+class SigLipVisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+
+        self.embeddings = VisionEmbeddings(config)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        spatial_shapes: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        x = self.embeddings(x, spatial_shapes)
+        x = x.astype(self.embeddings.patch_embedding.weight.dtype)
+        encoder_outputs = self.encoder(
+            x=x, output_hidden_states=output_hidden_states, mask=None
+        )
+        pooler_output = self.post_layernorm(encoder_outputs[0])
+        return pooler_output, x, encoder_outputs[-1]
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type not in ["siglip_vision_model"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.vision_model = SigLipVisionModel(config)
+
+    def __call__(
+        self,
+        x: mx.array,
+        spatial_shapes: Optional[mx.array] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        return self.vision_model(x, spatial_shapes, output_hidden_states)
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # PyTorch conv2d weight tensors have shape:
+                # [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight to be of shape:
+                # [out_channels, kH, KW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
mlx_vlm/models/base.py
ADDED
@@ -0,0 +1,356 @@
+import inspect
+import math
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
+from PIL import Image
+
+
+@dataclass
+class LanguageModelOutput:
+    logits: mx.array
+    hidden_states: Optional[List[mx.array]] = None
+    cross_attention_states: Optional[List[mx.array]] = None
+    encoder_outputs: Optional[List[mx.array]] = None
+
+
+@dataclass
+class InputEmbeddingsFeatures:
+    inputs_embeds: mx.array
+    attention_mask_4d: Optional[mx.array] = None
+    visual_pos_masks: Optional[mx.array] = None
+    deepstack_visual_embeds: Optional[mx.array] = None
+    per_layer_inputs: Optional[mx.array] = None
+    cross_attention_states: Optional[mx.array] = None
+    cross_attention_mask: Optional[mx.array] = None
+    full_text_row_masked_out_mask: Optional[mx.array] = None
+    decoder_inputs_embeds: Optional[mx.array] = None
+    attention_mask: Optional[mx.array] = None  # For encoder-decoder models
+
+    def to_dict(self):
+        return {
+            "inputs_embeds": self.inputs_embeds,
+            "attention_mask_4d": self.attention_mask_4d,
+            "visual_pos_masks": self.visual_pos_masks,
+            "deepstack_visual_embeds": self.deepstack_visual_embeds,
+            "per_layer_inputs": self.per_layer_inputs,
+            "cross_attention_states": self.cross_attention_states,
+            "cross_attention_mask": self.cross_attention_mask,
+            "full_text_row_masked_out_mask": self.full_text_row_masked_out_mask,
+            "decoder_inputs_embeds": self.decoder_inputs_embeds,
+            "attention_mask": self.attention_mask,
+        }
+
+
+@dataclass
+class BaseModelConfig:
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+    def to_dict(self):
+        return {k: v for k, v in self.__dict__.items() if v is not None}
+
+
+class BaseImageProcessor:
+    """
+    Base image processor class. Subclasses should implement preprocess().
+    Transformers imports are deferred to __init__ for faster module loading.
+    """
+
+    def __init__(
+        self,
+        image_mean=(0.5, 0.5, 0.5),
+        image_std=(0.5, 0.5, 0.5),
+        size=(384, 384),
+        crop_size: Dict[str, int] = None,
+        resample=None,
+        rescale_factor=1 / 255,
+        data_format=None,
+    ):
+        from transformers.image_processing_utils import get_size_dict
+        from transformers.image_utils import ChannelDimension, PILImageResampling
+
+        if resample is None:
+            resample = PILImageResampling.BICUBIC
+        if data_format is None:
+            data_format = ChannelDimension.FIRST
+
+        crop_size = (
+            crop_size if crop_size is not None else {"height": 384, "width": 384}
+        )
+        crop_size = get_size_dict(
+            crop_size, default_to_square=True, param_name="crop_size"
+        )
+
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.size = size
+        self.resample = resample
+        self.rescale_factor = rescale_factor
+        self.data_format = data_format
+        self.crop_size = crop_size
+
+    def rescale(
+        self,
+        image,
+        scale: float,
+        input_data_format: str = "channels_first",
+    ):
+        """Rescale an image by a scale factor."""
+        return image * scale
+
+    def normalize(
+        self,
+        image,
+        mean,
+        std,
+        input_data_format: str = "channels_first",
+    ):
+        """Normalize an image with mean and std."""
+        import numpy as np
+
+        mean = np.array(mean, dtype=image.dtype)
+        std = np.array(std, dtype=image.dtype)
+
+        if input_data_format == "channels_first":
+            # Image shape: [C, H, W]
+            mean = mean[:, None, None]
+            std = std[:, None, None]
+        else:
+            # Image shape: [H, W, C]
+            pass  # mean and std are already in correct shape
+
+        return (image - mean) / std
+
+    @abstractmethod
+    def preprocess(self, images):
+        pass
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) == 4:
+        out_channels, kH, KW, _ = shape
+        # Check if out_channels is the largest, and kH and KW are the same
+        if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+            return True
+        else:
+            return False
+    # Check if the shape has 3 dimensions
+    elif len(shape) == 3:
+        _, kW, out_channels = shape
+        # Check if out_channels is the largest
+        if kW >= out_channels:
+            return True
+        else:
+            return False
+    else:
+        return False
+
+
+def check_activation_stats(name, tensor):
+    """Helper function to check for anomalies and log stats."""
+
+    print(f"--- Activation Stats: {name} ---")
+    # Check for NaNs/Infs
+    has_nan = mx.isnan(tensor).any()
+    has_inf = mx.isinf(tensor).any()
+    if has_nan:
+        print(f"WARNING: Found NaN in {name}")
+    if has_inf:
+        print(f"WARNING: Found Inf in {name}")
+
+    # Calculate and print stats (ensure computation happens)
+    min_val = mx.min(tensor).item()
+    max_val = mx.max(tensor).item()
+    mean_val = mx.mean(tensor).item()
+    std_val = mx.std(tensor).item()
+    print(f" Shape: {tensor.shape}")
+    print(f" Min: {min_val:.4f}, Max: {max_val:.4f}")
+    print(f" Mean: {mean_val:.4f}, Std: {std_val:.4f}")
+    print("-" * (len(name) + 24))
+
+
+def pixel_shuffle(input_tensor, shuffle_ratio):
+    # input_tensor: [batch_size, num_patches, channels]
+    batch_size, num_patches, channels = input_tensor.shape
+    patch_size = int(math.sqrt(num_patches))
+
+    input_tensor = input_tensor.reshape(batch_size, patch_size, patch_size, -1)
+    batch_size, height, width, channels = input_tensor.shape
+
+    reshaped_tensor = input_tensor.reshape(
+        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+    )
+    reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3)
+
+    reshaped_tensor = reshaped_tensor.reshape(
+        batch_size,
+        int(height * shuffle_ratio),
+        int(width * shuffle_ratio),
+        int(channels / (shuffle_ratio**2)),
+    )
+    reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3)
+
+    output_tensor = reshaped_tensor.reshape(batch_size, -1, reshaped_tensor.shape[-1])
+    return output_tensor
+
+
+def interpolate(pos_embed, size, mode="cubic", align_corners=False):
+    """
+    MLX implementation of PyTorch's F.interpolate with bicubic mode
+
+    Args:
+        pos_embed: MLX array with shape [B, C, H_src, W_src] or [C, H_src, W_src]
+        size: Tuple (H_dst, W_dst) - target size
+        align_corners: Boolean - whether to align corners
+
+    Returns:
+        Interpolated array with shape [B, C, H_dst, W_dst] or [C, H_dst, W_dst]
+    """
+    # Handle different input shapes
+    input_dim = pos_embed.ndim
+    original_shape = pos_embed.shape
+
+    if input_dim == 3:
+        # [C, H, W] -> [1, C, H, W]
+        pos_embed = pos_embed.reshape(1, *original_shape)
+
+    # Get source dimensions
+    h_src, w_src = pos_embed.shape[-2:]
+    h_dst, w_dst = size
+
+    # Calculate scale factors
+    scale_h = h_dst / h_src
+    scale_w = w_dst / w_src
+
+    # Create upsampler
+    upsampler = nn.Upsample(
+        scale_factor=(scale_h, scale_w), mode=mode, align_corners=align_corners
+    )
+
+    # Apply upsampling
+    result = upsampler(pos_embed)
+
+    # Return in the original dimension format
+    if input_dim == 3:
+        return result.reshape(original_shape[0], *size)
+    return result
+
+
+@mx.compile
+def chunked_attention(
+    queries: mx.array,
+    keys: mx.array,
+    values: mx.array,
+    scale: float,
+    chunk_size: int,
+) -> mx.array:
+
+    L = queries.shape[2]
+
+    outputs = []
+    for i in range(0, L, chunk_size):
+        end_idx = min(i + chunk_size, L)
+        q_chunk = queries[:, :, i:end_idx, :]  # (B, n_heads, chunk, head_dim)
+
+        chunk_output = mx.fast.scaled_dot_product_attention(
+            q_chunk, keys, values, scale=scale
+        )
+
+        outputs.append(chunk_output)
+
+    return mx.concatenate(outputs, axis=2)  # (B, n_heads, L, head_dim)
+
+
+def install_auto_processor_patch(target_model_types, processor_cls):
+    """
+    Install a composable patch on transformers.AutoProcessor.from_pretrained
+
+    Args:
+        target_model_types (Union[str, List[str]]): Model types to intercept.
+        processor_cls (type): Processor class exposing `from_pretrained`.
+
+    Returns:
+        The previous `AutoProcessor.from_pretrained` for reference.
+    """
+    from transformers import AutoProcessor as _HF_AutoProcessor
+
+    if isinstance(target_model_types, str):
+        target_model_types = [target_model_types]
+    target_model_types = {t.lower() for t in target_model_types}
+
+    previous_from_pretrained = _HF_AutoProcessor.from_pretrained
+
+    @classmethod
+    def _patched_auto_processor_from_pretrained(
+        cls, pretrained_model_name_or_path, **kwargs
+    ):
+        import json as _json
+        from pathlib import Path
+
+        try:
+            model_path = Path(pretrained_model_name_or_path)
+            is_local = model_path.exists() and model_path.is_dir()
+
+            cfg = {}
+            if is_local:
+                config_path = model_path / "config.json"
+                if config_path.exists():
+                    with open(config_path, "r", encoding="utf-8") as f:
+                        cfg = _json.load(f)
+            else:
+                try:
+                    from huggingface_hub import hf_hub_download
+
+                    cfg_path = hf_hub_download(
+                        pretrained_model_name_or_path, "config.json"
+                    )
+                    with open(cfg_path, "r", encoding="utf-8") as f:
+                        cfg = _json.load(f)
+                except Exception:
+                    cfg = {}
+
+            model_type = str(cfg.get("model_type", "")).lower()
+            if model_type in target_model_types:
+                return processor_cls.from_pretrained(
+                    pretrained_model_name_or_path, **kwargs
+                )
+        except Exception:
+            # On any failure, fall back to previous behavior
+            pass
+
+        # Chain to the prior from_pretrained (which may already be patched)
+        return previous_from_pretrained.__func__(
+            cls, pretrained_model_name_or_path, **kwargs
+        )
+
+    _HF_AutoProcessor.from_pretrained = _patched_auto_processor_from_pretrained
+    return previous_from_pretrained
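
For orientation only, and not part of the published diff: a minimal sketch of how a model package might call install_auto_processor_patch from mlx_vlm/models/base.py. The MyProcessor class and the "my_vlm" model type below are hypothetical placeholders, not names from the package.

# Hypothetical usage sketch; MyProcessor and "my_vlm" are illustrative names.
from mlx_vlm.models.base import install_auto_processor_patch


class MyProcessor:
    """Hypothetical processor exposing the from_pretrained interface."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        # A real processor would load its tokenizer/image-processing assets here.
        return cls()


# Install the patch; the previous AutoProcessor.from_pretrained is returned
# so callers can restore it or chain to it later.
previous_from_pretrained = install_auto_processor_patch("my_vlm", MyProcessor)

# From here on, transformers.AutoProcessor.from_pretrained(...) returns MyProcessor
# for checkpoints whose config.json declares model_type == "my_vlm", and falls back
# to the prior behavior for every other model type.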