fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr/sam.py
@@ -0,0 +1,489 @@
from typing import Optional, Tuple, Type

import mlx.core as mx
import mlx.nn as nn


def get_abs_pos_sam(abs_pos, tgt_size):
    """Interpolate absolute positional embeddings to target size."""
    dtype = abs_pos.dtype
    src_size = abs_pos.shape[1]

    if src_size != tgt_size:
        # Transpose to (B, C, H, W) for interpolation
        old_pos_embed = abs_pos.transpose(0, 3, 1, 2)
        old_pos_embed = old_pos_embed.astype(mx.float32)

        # Bicubic interpolation
        from ..kernels import bicubic_interpolate

        new_pos_embed = bicubic_interpolate(
            old_pos_embed, size=(tgt_size, tgt_size), antialias=True
        ).astype(dtype)

        # Transpose back to (B, H, W, C)
        new_pos_embed = new_pos_embed.transpose(0, 2, 3, 1)
        return new_pos_embed
    else:
        return abs_pos

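# Usage sketch (illustrative; shapes are assumptions, not taken from the
# package): resizing a 64x64 pretrain position grid to a 32x32 feature grid.
# `tgt_size` corresponds to x.shape[1] as passed from SAMEncoder.__call__.
def _example_get_abs_pos_sam() -> tuple:
    pos_embed = mx.zeros((1, 64, 64, 768))  # (B, H, W, C) pretrain grid
    resized = get_abs_pos_sam(pos_embed, 32)
    return resized.shape  # expected: (1, 32, 32, 768)
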
class MLPBlock(nn.Module):
    """MLP block with GELU activation."""

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def __call__(self, x: mx.array) -> mx.array:
        return self.lin2(self.act(self.lin1(x)))

class Attention(nn.Module):
    """Multi-head attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # Initialize relative positional embeddings
            self.rel_pos_h = mx.zeros((2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = mx.zeros((2 * input_size[1] - 1, head_dim))

    def __call__(self, x: mx.array) -> mx.array:
        B, H, W, _ = x.shape

        # QKV projection and reshape
        qkv = (
            self.qkv(x)
            .reshape(B, H * W, 3, self.num_heads, -1)
            .transpose(2, 0, 3, 1, 4)
        )

        # Separate q, k, v
        qkv_reshaped = qkv.reshape(3, B * self.num_heads, H * W, -1)
        q, k, v = qkv_reshaped[0], qkv_reshaped[1], qkv_reshaped[2]

        # Compute relative positional embeddings if needed
        rel_h, rel_w = None, None
        if self.use_rel_pos:
            rel_h, rel_w = add_decomposed_rel_pos(
                q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
            )

        # Reshape for attention
        q = q.reshape(B, self.num_heads, H * W, -1)
        k = k.reshape(B, self.num_heads, H * W, -1)
        v = v.reshape(B, self.num_heads, H * W, -1)

        # Apply scaled dot product attention
        if self.use_rel_pos:
            rel_h = rel_h.reshape(
                B, self.num_heads, rel_h.shape[1], rel_h.shape[2], rel_h.shape[3]
            )
            rel_w = rel_w.reshape(
                B, self.num_heads, rel_w.shape[1], rel_w.shape[2], rel_w.shape[3]
            )
            attn_bias = (rel_h + rel_w).reshape(
                B, self.num_heads, rel_h.shape[2], rel_h.shape[3] * rel_w.shape[4]
            )
            x = mx.fast.scaled_dot_product_attention(
                q, k, v, scale=self.scale, mask=attn_bias
            )
        else:
            x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)

        # Reshape output
        x = (
            x.reshape(B, self.num_heads, H, W, -1)
            .transpose(0, 2, 3, 1, 4)
            .reshape(B, H, W, -1)
        )

        x = self.proj(x)
        return x

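# Usage sketch (illustrative; sizes are assumptions): a 14x14 token grid with
# decomposed relative position bias enabled. __init__ sizes rel_pos_h and
# rel_pos_w as (2 * 14 - 1, head_dim), and __call__ folds the bias into the
# mask argument of mx.fast.scaled_dot_product_attention.
def _example_attention() -> tuple:
    attn = Attention(dim=768, num_heads=12, use_rel_pos=True, input_size=(14, 14))
    y = attn(mx.zeros((1, 14, 14, 768)))  # (B, H, W, C) in, same shape out
    return y.shape  # expected: (1, 14, 14, 768)
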
class Block(nn.Module):
    """Transformer block with support for window attention and residual propagation."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

        self.norm2 = norm_layer(dim, eps=1e-6)
        self.mlp = MLPBlock(
            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
        )

        self.window_size = window_size

    def __call__(self, x: mx.array) -> mx.array:
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x

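# Usage sketch (illustrative; sizes are assumptions): with window_size=7 a
# 14x14 grid is split into four 7x7 windows before attention and stitched back
# afterwards; with window_size=0 the block attends globally.
def _example_block() -> tuple:
    blk = Block(
        dim=768, num_heads=12, use_rel_pos=True, window_size=7, input_size=(14, 14)
    )
    y = blk(mx.zeros((1, 14, 14, 768)))
    return y.shape  # expected: (1, 14, 14, 768)
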
class PatchEmbed(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride
        )

    def __call__(self, x: mx.array) -> mx.array:
        x = self.proj(x)
        return x

class SAMEncoder(nn.Module):
    """Vision Transformer encoder based on the SAM architecture."""

    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = True,
        window_size: int = 14,
        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
        final_out_chans: int = 1024,
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            out_chans (int): Output channels for neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (tuple): Indexes for blocks using global attention.
            final_out_chans (int): Final output channels after net_3 (1024 for OCR, 896 for OCR-2).
        """
        super().__init__()
        self.img_size = img_size

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.use_abs_pos = use_abs_pos
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size
            self.pos_embed = mx.zeros(
                (1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = []
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        # Neck layers for output processing
        self.neck = [
            nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
            nn.LayerNorm(out_chans, eps=1e-6),
            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
            nn.LayerNorm(out_chans, eps=1e-6),
        ]

        # Additional downsampling layers
        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(
            512, final_out_chans, kernel_size=3, stride=2, padding=1, bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        # Patch embedding
        x = self.patch_embed(x)

        # Add positional embeddings
        if self.use_abs_pos:
            x = x + get_abs_pos_sam(self.pos_embed, x.shape[1])

        # Apply transformer blocks
        for blk in self.blocks:
            x = blk(x)

        # Apply neck layers
        for n in self.neck:
            x = n(x)

        # Additional downsampling
        x = self.net_2(x)
        x = self.net_3(x)

        return x

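# End-to-end sketch (illustrative; a small assumed config, not the shipped OCR
# defaults): a 224x224 NHWC image yields a 14x14 patch grid, the neck maps to
# out_chans=256, and net_2/net_3 each halve the spatial size.
def _example_sam_encoder() -> tuple:
    enc = SAMEncoder(img_size=224, depth=2, global_attn_indexes=(1,))
    feats = enc(mx.zeros((1, 224, 224, 3)))
    return feats.shape  # expected: (1, 4, 4, 1024)
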
# Utility functions


def window_partition(x: mx.array, window_size: int) -> Tuple[mx.array, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (mx.array): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size

    if pad_h > 0 or pad_w > 0:
        x = mx.pad(x, [(0, 0), (0, pad_h), (0, pad_w), (0, 0)])

    Hp, Wp = H + pad_h, W + pad_w

    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

    return windows, (Hp, Wp)

def window_unpartition(
    windows: mx.array,  # FIXED: Changed from np.ndarray to mx.array
    window_size: int,
    pad_hw: Tuple[int, int],
    hw: Tuple[int, int],
) -> mx.array:  # FIXED: Changed return type from implicit to mx.array
    """
    Window unpartition into original sequences, removing padding.

    Args:
        windows (mx.array): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)

    x = windows.reshape(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :]

    return x

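# Round-trip sketch (illustrative; sizes are assumptions): a 10x10 grid is
# padded to 14x14 for 7x7 windows, then restored to its original shape.
def _example_window_round_trip() -> bool:
    x = mx.arange(1 * 10 * 10 * 4, dtype=mx.float32).reshape(1, 10, 10, 4)
    windows, pad_hw = window_partition(x, 7)  # (4, 7, 7, 4), pad_hw == (14, 14)
    y = window_unpartition(windows, 7, pad_hw, (10, 10))
    return y.shape == x.shape  # expected: True
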
def get_rel_pos(q_size: int, k_size: int, rel_pos: mx.array) -> mx.array:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (mx.array): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    # Interpolate rel pos if needed
    if rel_pos.shape[0] != max_rel_dist:
        dtype = rel_pos.dtype
        rel_pos = rel_pos.astype(mx.float32)
        rel_pos_resized = rel_pos.reshape(1, rel_pos.shape[0], -1).transpose(0, 2, 1)

        # Linear interpolation
        scale = rel_pos_resized.shape[2] / max_rel_dist
        indices = mx.arange(max_rel_dist, dtype=mx.float32) * scale
        idx_floor = mx.floor(indices).astype(mx.int32)
        idx_ceil = mx.minimum(idx_floor + 1, rel_pos_resized.shape[2] - 1)
        weight = indices - idx_floor.astype(mx.float32)

        rel_pos_resized = (
            mx.take(rel_pos_resized, idx_floor, axis=2) * (1 - weight)
            + mx.take(rel_pos_resized, idx_ceil, axis=2) * weight
        ).astype(dtype)

        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).transpose(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different
    q_coords = mx.arange(q_size, dtype=mx.float32)[:, None] * max(k_size / q_size, 1.0)
    k_coords = mx.arange(k_size, dtype=mx.float32)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.astype(mx.int32)]

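# Shape sketch (illustrative; sizes are assumptions): with q_size == k_size == 8
# the table needs 2 * 8 - 1 = 15 rows, so a (15, C) table is indexed directly,
# while a (9, C) table is first linearly interpolated up to 15 rows.
def _example_get_rel_pos() -> tuple:
    exact = get_rel_pos(8, 8, mx.zeros((15, 32)))  # no resampling needed
    resized = get_rel_pos(8, 8, mx.zeros((9, 32)))  # interpolated to 15 rows
    return exact.shape, resized.shape  # expected: (8, 8, 32) for both
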
def add_decomposed_rel_pos(
    q: mx.array,  # FIXED: Changed from np.ndarray to mx.array
    rel_pos_h: mx.array,  # FIXED: Changed from np.ndarray to mx.array
    rel_pos_w: mx.array,  # FIXED: Changed from np.ndarray to mx.array
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> Tuple[mx.array, mx.array]:  # FIXED: Added explicit return type
    """
    Calculate decomposed relative positional embeddings.

    Args:
        q (mx.array): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (mx.array): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (mx.array): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        Tuple of (rel_h, rel_w): relative position biases for height and width.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size

    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)

    rel_h = mx.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = mx.einsum("bhwc,wkc->bhwk", r_q, Rw)
    rel_h = rel_h[..., None]
    rel_w = rel_w[..., None, :]
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w