fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""Vision encoder for Jina VLM in MLX."""
|
|
2
|
+
|
|
3
|
+
from typing import List, Tuple
|
|
4
|
+
|
|
5
|
+
import mlx.core as mx
|
|
6
|
+
import mlx.nn as nn
|
|
7
|
+
|
|
8
|
+
from .config import VisionConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class PatchEmbedding(nn.Module):
    """Turn pixels (or pre-cut patches) into patch embeddings.

    Accepts either an image batch shaped ``(B, C, H, W)`` or an already
    patchified batch shaped ``(B, n_patches, C * patch_size * patch_size)``
    and applies a single linear projection (``proj``) to every patch.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.patch_size = config.patch_size
        self.num_channels = config.num_channels
        self.hidden_size = config.hidden_size

        # The attribute name ``proj`` must match the checkpoint key.
        flat_patch = config.num_channels * config.patch_size * config.patch_size
        self.proj = nn.Linear(flat_patch, config.hidden_size, bias=config.use_bias)

    def __call__(self, x: mx.array) -> Tuple[mx.array, Tuple[int, int]]:
        """Project patches to ``hidden_size``.

        Returns:
            ``(embeddings, (grid_h, grid_w))`` where ``embeddings`` has shape
            ``(B, n_patches, hidden_size)``. For pre-patchified input the grid
            is assumed square: ``grid_h = grid_w = int(sqrt(n_patches))``.
        """
        if x.ndim != 3:
            # Pixel input (B, C, H, W): cut into non-overlapping p x p tiles
            # and flatten each tile in (C, p_h, p_w) order.
            batch, chans, height, width = x.shape
            p = self.patch_size
            grid_h, grid_w = height // p, width // p
            tiles = x.reshape(batch, chans, grid_h, p, grid_w, p)
            tiles = tiles.transpose(0, 2, 4, 1, 3, 5)
            x = tiles.reshape(batch, grid_h * grid_w, chans * p * p)
        else:
            # Already patchified: (B, n_patches, patch_dim).
            grid_h = grid_w = int(x.shape[1] ** 0.5)
        return self.proj(x), (grid_h, grid_w)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class VisionMLP(nn.Module):
    """Two-layer feed-forward block computing ``down(gelu(up(x)))``.

    Sub-module names ``up`` / ``down`` mirror the checkpoint keys
    ``ffn.up`` / ``ffn.down``.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        use_bias = config.use_bias
        self.up = nn.Linear(
            config.hidden_size, config.intermediate_size, bias=use_bias
        )
        self.down = nn.Linear(
            config.intermediate_size, config.hidden_size, bias=use_bias
        )
        # "gelu_pytorch_tanh" selects the tanh approximation; any other
        # activation name falls back to nn.GELU() with its default setting.
        tanh_approx = config.activation == "gelu_pytorch_tanh"
        self.gelu = nn.GELU(approx="tanh") if tanh_approx else nn.GELU()

    def __call__(self, x: mx.array) -> mx.array:
        return self.down(self.gelu(self.up(x)))
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class VisionAttention(nn.Module):
    """Unmasked multi-head self-attention over patch tokens.

    Sub-module names ``qkv`` / ``out`` mirror the checkpoint keys
    ``attn.qkv`` / ``attn.out``.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim**-0.5

        inner_dim = config.num_attention_heads * config.head_dim
        # Fused Q/K/V projection: one matmul producing 3 * inner_dim features.
        self.qkv = nn.Linear(
            config.hidden_size,
            3 * inner_dim,
            bias=config.use_bias,
        )
        self.out = nn.Linear(
            inner_dim,
            config.hidden_size,
            bias=config.use_bias,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Attend over every token of ``x`` (shape ``(B, L, hidden_size)``)."""
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
        # -> (3, B, num_heads, L, head_dim)
        qkv = qkv.transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Fused softmax(q @ k^T * scale) @ v. MLX performs the softmax in
        # float32 regardless of input dtype and can avoid materializing the
        # full (B, H, L, L) score matrix, unlike the manual formulation.
        x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)

        x = x.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.out(x)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class VisionEncoderLayer(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual.

    Attribute names (``attn_norm``, ``attn``, ``ffn_norm``, ``ffn``) mirror
    the checkpoint keys.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()

        def make_norm():
            return nn.LayerNorm(
                config.hidden_size, eps=config.layer_norm_eps, bias=config.use_bias
            )

        self.attn_norm = make_norm()
        self.attn = VisionAttention(config)
        self.ffn_norm = make_norm()
        self.ffn = VisionMLP(config)

    def __call__(self, x: mx.array) -> mx.array:
        h = x + self.attn(self.attn_norm(x))
        return h + self.ffn(self.ffn_norm(h))
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class VisionModel(nn.Module):
    """SigLIP-style ViT encoder that also exposes per-layer hidden states."""

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        self.hidden_size = config.hidden_size
        self.vit_layers = config.vit_layers

        # Checkpoint key: patch_embed.proj
        self.patch_embed = PatchEmbedding(config)

        # Checkpoint key: pos_embed, stored 2-D as (num_tokens, hidden_size).
        # One extra row is reserved for the class token when it is enabled.
        grid = config.image_size // config.patch_size
        num_tokens = grid * grid + (1 if config.use_cls_token else 0)
        self.cls_token = (
            mx.zeros((1, 1, config.hidden_size)) if config.use_cls_token else None
        )
        self.pos_embed = mx.zeros((num_tokens, config.hidden_size))

        self.layers = [
            VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
        ]

        # Checkpoint key: post_norm (optional final LayerNorm).
        self.post_norm = (
            nn.LayerNorm(
                config.hidden_size, eps=config.layer_norm_eps, bias=config.use_bias
            )
            if config.post_layer_norm
            else None
        )

    def __call__(self, x: mx.array) -> Tuple[mx.array, List[mx.array]]:
        """Encode a batch of images.

        Returns:
            ``(output, hidden_states)``. ``output`` is the final token sequence
            (after ``post_norm`` when configured). ``hidden_states`` holds every
            layer's output in order, followed by the ``post_norm`` output when
            ``post_norm`` exists.
        """
        x, _ = self.patch_embed(x)

        if self.cls_token is not None:
            cls = mx.broadcast_to(self.cls_token, (x.shape[0], 1, self.hidden_size))
            x = mx.concatenate([cls, x], axis=1)

        # Broadcast the 2-D positional table over the batch axis.
        x = x + self.pos_embed[None]

        hidden_states = []
        for block in self.layers:
            x = block(x)
            hidden_states.append(x)

        if self.post_norm is not None:
            x = self.post_norm(x)
            hidden_states.append(x)

        return x, hidden_states

    def get_features(self, images: mx.array) -> mx.array:
        """Concatenate the hidden states selected by ``vit_layers``.

        ``vit_layers`` indexes the full ``hidden_states`` list returned by
        ``__call__``. When ``post_norm`` is configured, that list ends with the
        ``post_norm`` output. With 27 layers plus ``post_norm`` the list has 28
        entries (0-26 = layer outputs, 27 = post_norm output), so
        ``vit_layers=[-4, -10]`` selects layers 24 and 18 (not 23 and 17).

        The class token, if present, is dropped from each selected state. The
        states are concatenated along the last axis in ``vit_layers`` order.
        """
        _, hidden_states = self(images)

        def strip_cls(state):
            return state if self.cls_token is None else state[:, 1:]

        selected = [strip_cls(hidden_states[idx]) for idx in self.vit_layers]
        return mx.concatenate(selected, axis=-1)
|
|
@@ -0,0 +1,447 @@
|
|
|
1
|
+
import functools

import mlx.core as mx
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Metal source for nearest-neighbour resizing. The per-axis source-index
# scales are computed on the host (see _nearest_source_scale) and passed in
# ``params`` so that the kernel mirrors PyTorch's coordinate mapping exactly.
_NEAREST_SOURCE = """
    uint x_out = thread_position_in_grid.x;
    uint y_out = thread_position_in_grid.y;
    uint bc_idx = thread_position_in_grid.z;

    int batch_size = dims[0];
    int channels = dims[1];
    int in_h = dims[2];
    int in_w = dims[3];
    int out_h = dims[4];
    int out_w = dims[5];

    if (x_out >= (uint)out_w || y_out >= (uint)out_h || bc_idx >= (uint)(batch_size * channels))
        return;

    int c = bc_idx % channels;
    int b = bc_idx / channels;

    // Source-index scales supplied by the host.
    float scale_h = params[0];
    float scale_w = params[1];

    // PyTorch 'nearest': src = floor(dst * scale)
    int y_in = int(floor(float(y_out) * scale_h));
    int x_in = int(floor(float(x_out) * scale_w));

    // Clamp to bounds
    y_in = max(0, min(y_in, in_h - 1));
    x_in = max(0, min(x_in, in_w - 1));

    int input_offset = ((b * channels + c) * in_h + y_in) * in_w + x_in;
    int output_offset = ((b * channels + c) * out_h + y_out) * out_w + x_out;

    output[output_offset] = input[input_offset];
"""


@functools.lru_cache(maxsize=None)
def _nearest_kernel():
    """Build the nearest-neighbour Metal kernel once and reuse it."""
    return mx.fast.metal_kernel(
        name="nearest_interpolation",
        input_names=["input", "dims", "params"],
        output_names=["output"],
        source=_NEAREST_SOURCE,
    )


def _nearest_source_scale(in_size, out_size, scale_factor):
    """Return the per-axis source-index scale PyTorch uses for 'nearest'.

    Mirrors PyTorch's CPU ``nearest_idx``: identity when the size is unchanged,
    ``0.5`` (index >> 1) for an exact 2x upsample, ``1 / scale_factor`` when a
    positive scale factor was supplied, and ``in_size / out_size`` otherwise.
    """
    if out_size == in_size:
        return 1.0
    if out_size == 2 * in_size:
        return 0.5
    if scale_factor is not None and scale_factor > 0:
        return 1.0 / scale_factor
    return in_size / out_size if out_size > 0 else 0.0


def nearest_interpolate(x, size=None, scale_factor=None):
    """
    Nearest neighbor interpolation matching PyTorch's ``mode="nearest"``.

    Each output pixel copies input pixel ``floor(dst * s)`` (clamped), where the
    per-axis scale ``s`` is chosen as PyTorch does. When a ``scale_factor`` is
    given, PyTorch uses ``1 / scale_factor`` rather than ``in / out``; the two
    differ whenever ``in * scale_factor`` is not an integer.

    Args:
        x: MLX tensor of shape [B, C, H, W]
        size: int or tuple of (out_h, out_w), or None
        scale_factor: float or tuple of (scale_h, scale_w), or None

    Returns:
        MLX tensor of shape [B, C, out_h, out_w] in x's dtype (the kernel
        itself runs in float32).

    Raises:
        ValueError: if neither size nor scale_factor is given.
    """
    batch_size, channels, in_h, in_w = x.shape

    if size is not None:
        if isinstance(size, int):
            size = (size, size)
        out_h, out_w = size
        # With an explicit size PyTorch derives the scale from the shapes.
        factor_h = factor_w = None
    elif scale_factor is not None:
        if isinstance(scale_factor, (int, float)):
            factor_h = factor_w = scale_factor
        else:
            factor_h, factor_w = scale_factor
        out_h, out_w = int(in_h * factor_h), int(in_w * factor_w)
    else:
        raise ValueError("Either size or scale_factor must be specified")

    dims = mx.array([batch_size, channels, in_h, in_w, out_h, out_w], dtype=mx.int32)
    params = mx.array(
        [
            _nearest_source_scale(in_h, out_h, factor_h),
            _nearest_source_scale(in_w, out_w, factor_w),
        ],
        dtype=mx.float32,
    )

    # Flatten to 1-D for the kernel and compute in float32.
    input_dtype = x.dtype
    x_flat = x.reshape(-1)
    if input_dtype != mx.float32:
        x_flat = x_flat.astype(mx.float32)

    outputs = _nearest_kernel()(
        inputs=[x_flat, dims, params],
        grid=(out_w, out_h, batch_size * channels),
        threadgroup=get_optimal_threadgroup(out_w, out_h),
        output_shapes=[(batch_size * channels * out_h * out_w,)],
        output_dtypes=[mx.float32],
    )

    result = outputs[0].reshape(batch_size, channels, out_h, out_w)
    if input_dtype != mx.float32:
        result = result.astype(input_dtype)
    return result
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
_BICUBIC_HEADER = """
// Cubic convolution kernel (Keys) with a = -0.5.
float cubic_kernel(float x) {
    float absx = fabs(x);
    float absx2 = absx * absx;
    float absx3 = absx2 * absx;

    const float a = -0.5f;

    if (absx <= 1.0f) {
        return (a + 2.0f) * absx3 - (a + 3.0f) * absx2 + 1.0f;
    } else if (absx < 2.0f) {
        return a * absx3 - 5.0f * a * absx2 + 8.0f * a * absx - 4.0f * a;
    }
    return 0.0f;
}

// Antialiased variant: stretch the kernel by `scale` when downsampling so it
// integrates over a wider input region.
float cubic_kernel_antialias(float x, float scale) {
    return cubic_kernel(x / scale);
}
"""

# params layout: [scale_h, scale_w, align_corners, use_antialias,
#                 filter_scale_h, filter_scale_w, support]
# (scale_h / scale_w are passed for layout compatibility; the kernel derives
# coordinates from the shapes instead.)
_BICUBIC_SOURCE = """
    uint x_out = thread_position_in_grid.x;
    uint y_out = thread_position_in_grid.y;
    uint bc_idx = thread_position_in_grid.z;

    int batch_size = dims[0];
    int channels = dims[1];
    int in_h = dims[2];
    int in_w = dims[3];
    int out_h = dims[4];
    int out_w = dims[5];

    bool align_corners = params[2] > 0.5f;
    bool use_antialias = params[3] > 0.5f;
    float filter_scale_h = params[4];
    float filter_scale_w = params[5];
    float support = params[6];

    if (x_out >= (uint)out_w || y_out >= (uint)out_h || bc_idx >= (uint)(batch_size * channels))
        return;

    int c = bc_idx % channels;
    int b = bc_idx / channels;

    float x_in, y_in;
    if (align_corners) {
        // Corner pixel centers map onto each other. Handled per axis: a
        // length-1 output axis samples source index 0 (PyTorch uses scale 0).
        x_in = out_w > 1 ? float(x_out) * float(in_w - 1) / float(out_w - 1) : 0.0f;
        y_in = out_h > 1 ? float(y_out) * float(in_h - 1) / float(out_h - 1) : 0.0f;
    } else {
        // Half-pixel mapping (PyTorch's align_corners=False)
        x_in = ((float(x_out) + 0.5f) / float(out_w)) * float(in_w) - 0.5f;
        y_in = ((float(y_out) + 0.5f) / float(out_h)) * float(in_h) - 0.5f;
    }

    // Filter radius in input pixels, widened when antialiasing.
    float support_h = use_antialias ? support * filter_scale_h : support;
    float support_w = use_antialias ? support * filter_scale_w : support;

    // Input pixel window [start, end) covered by the filter.
    int y_start = int(floor(y_in - support_h)) + 1;
    int y_end = int(floor(y_in + support_h)) + 1;
    int x_start = int(floor(x_in - support_w)) + 1;
    int x_end = int(floor(x_in + support_w)) + 1;

    y_start = max(0, y_start);
    y_end = min(in_h, y_end);
    x_start = max(0, x_start);
    x_end = min(in_w, x_end);

    float result = 0.0f;
    float weight_sum = 0.0f;

    for (int y_pos = y_start; y_pos < y_end; y_pos++) {
        float dy = float(y_pos) - y_in;
        float wy = use_antialias ?
            cubic_kernel_antialias(dy, filter_scale_h) :
            cubic_kernel(dy);

        for (int x_pos = x_start; x_pos < x_end; x_pos++) {
            float dx = float(x_pos) - x_in;
            float wx = use_antialias ?
                cubic_kernel_antialias(dx, filter_scale_w) :
                cubic_kernel(dx);

            float weight = wy * wx;
            int input_offset = ((b * channels + c) * in_h + y_pos) * in_w + x_pos;
            result += input[input_offset] * weight;
            weight_sum += weight;
        }
    }

    // Renormalize (weights may be truncated at the image border).
    if (weight_sum > 1e-8f) {
        result /= weight_sum;
    }

    int output_offset = ((b * channels + c) * out_h + y_out) * out_w + x_out;
    output[output_offset] = result;
"""


@functools.lru_cache(maxsize=None)
def _bicubic_kernel():
    """Build the bicubic Metal kernel once and reuse it across calls."""
    return mx.fast.metal_kernel(
        name="bicubic_interpolation_antialias",
        input_names=["input", "dims", "params"],
        output_names=["output"],
        source=_BICUBIC_SOURCE,
        header=_BICUBIC_HEADER,
    )


def bicubic_interpolate(
    x, size=None, scale_factor=None, align_corners=False, antialias=False
):
    """
    Bicubic interpolation implemented as a custom Metal kernel.

    Args:
        x: MLX tensor of shape [B, C, H, W]
        size: int or tuple of (out_h, out_w), or None
        scale_factor: float or tuple of (scale_h, scale_w), or None
        align_corners: whether corner pixel centers of input and output align
        antialias: when downsampling, stretch the cubic filter by the inverse
            scale on each downsampled axis

    Returns:
        Interpolated MLX tensor of shape [B, C, out_h, out_w] in x's dtype
        (the kernel itself computes in float32).

    Raises:
        ValueError: if neither size nor scale_factor is given.

    NOTE(review): the cubic coefficient is always a = -0.5, and weights cut off
    at the image border are renormalized. PyTorch's non-antialiased bicubic
    uses a = -0.75 and replicates border pixels instead, so results for
    antialias=False can differ slightly from torch. This is left unchanged
    pending verification against the models that call this function.
    """
    batch_size, channels, in_h, in_w = x.shape

    if size is not None:
        if isinstance(size, int):
            size = (size, size)
        out_h, out_w = size
        scale_h, scale_w = out_h / in_h, out_w / in_w
    elif scale_factor is not None:
        if isinstance(scale_factor, (int, float)):
            scale_h = scale_w = scale_factor
        else:
            scale_h, scale_w = scale_factor
        out_h, out_w = int(in_h * scale_h), int(in_w * scale_w)
    else:
        raise ValueError("Either size or scale_factor must be specified")

    # The cubic filter spans 2 input pixels on each side of the sample point.
    support = 2.0
    antialias_flag = 1.0 if (antialias and (scale_h < 1.0 or scale_w < 1.0)) else 0.0

    # When antialiasing a downsampled axis, widen the filter by 1 / scale.
    filter_scale_h = 1.0 / scale_h if (antialias and scale_h < 1.0) else 1.0
    filter_scale_w = 1.0 / scale_w if (antialias and scale_w < 1.0) else 1.0

    params = mx.array(
        [
            scale_h,
            scale_w,
            1.0 if align_corners else 0.0,
            antialias_flag,
            filter_scale_h,
            filter_scale_w,
            support,
        ],
        dtype=mx.float32,
    )
    dims = mx.array([batch_size, channels, in_h, in_w, out_h, out_w], dtype=mx.int32)

    # Flatten to 1-D for the kernel and compute in float32.
    input_dtype = x.dtype
    x_flat = x.reshape(-1)
    if input_dtype != mx.float32:
        x_flat = x_flat.astype(mx.float32)

    outputs = _bicubic_kernel()(
        inputs=[x_flat, dims, params],
        grid=(out_w, out_h, batch_size * channels),
        threadgroup=get_optimal_threadgroup(out_w, out_h),
        output_shapes=[(batch_size * channels * out_h * out_w,)],
        output_dtypes=[mx.float32],
    )

    result = outputs[0].reshape(batch_size, channels, out_h, out_w)
    if input_dtype != mx.float32:
        result = result.astype(input_dtype)
    return result
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def grid_sample(x, grid):
    """
    Bilinearly sample ``x`` at the normalized locations given by ``grid``.

    Implemented as a custom Metal kernel (``mx.fast.metal_kernel``) that runs
    one GPU thread per output element. Sampling coordinates are unnormalized
    with ``((g + 1) * size - 1) / 2`` -- the same convention as PyTorch's
    ``grid_sample(..., align_corners=False)`` -- and any of the four bilinear
    taps that falls outside the input contributes zero (zero padding).

    Args:
        x: MLX tensor of shape [B, H, W, C] (channels-last).
        grid: MLX tensor of shape [B, gN, gM, 2]. The last axis holds the
            (x, y) sampling coordinates; -1 and 1 address the outer edges of
            the input along W (for x) and H (for y).

    Returns:
        MLX tensor of shape [B, gN, gM, C] with the same dtype as ``x``.
    """
    assert x.ndim == 4, "`x` must be 4D."
    assert grid.ndim == 4, "`grid` must be 4D."

    B, _, _, C = x.shape
    gB, gN, gM, D = grid.shape
    out_shape = (B, gN, gM, C)

    assert D == 2, "Last dim of `grid` must be size 2."
    # The kernel derives the grid offset and the x batch offset from the same
    # flat output index, so a grid with a different batch size would be read
    # out of bounds (or misaligned) on the GPU.
    assert gB == B, "`x` and `grid` must have the same batch size."

    # Host-side integer product: the dispatch size must be a plain int, and
    # this avoids building and evaluating a device array just to multiply
    # four Python ints.
    num_elements = B * gN * gM * C
    if num_elements == 0:
        # Nothing to sample; skip a zero-sized GPU dispatch.
        return mx.zeros(out_shape, dtype=x.dtype)

    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        int gH = grid_shape[1];
        int gW = grid_shape[2];

        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;

        uint grid_idx = elem / C * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;

        int ix_nw = floor(ix);
        int iy_nw = floor(iy);

        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;

        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;

        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;

        T nw = (ix_se - ix) * (iy_se - iy);
        T ne = (ix - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix) * (iy - iy_ne);
        T se = (ix - ix_nw) * (iy - iy_nw);

        int batch_idx = elem / C / gH / gW * b_stride;
        int channel_idx = elem % C;
        int base_idx = batch_idx + channel_idx;

        T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
        T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
        T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
        T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];

        I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
        I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
        I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
        I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;

        out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
    """

    kernel = mx.fast.metal_kernel(
        name="grid_sample",
        input_names=["x", "grid"],
        output_names=["out"],
        source=source,
    )

    outputs = kernel(
        inputs=[x, grid],
        template=[("T", x.dtype)],
        output_shapes=[out_shape],
        output_dtypes=[x.dtype],
        # One thread per output element (the kernel indexes by
        # thread_position_in_grid.x only).
        grid=(num_elements, 1, 1),
        threadgroup=(256, 1, 1),
    )
    return outputs[0]
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def get_optimal_threadgroup(out_w, out_h):
    """
    Choose a 2D Metal threadgroup size for a dispatch over ``out_w`` x ``out_h``.

    Each side starts as the largest power of two not exceeding its work
    dimension (capped at ``MAX_THREADS_PER_DIM``), is raised to at least
    ``MIN_THREADS_PER_DIM``, and is then halved -- larger side first -- until
    the total thread count fits within ``MAX_THREADS_PER_GROUP``.

    Args:
        out_w: Number of threads along x (output width).
        out_h: Number of threads along y (output height).

    Returns:
        A ``(width, height, 1)`` tuple of ints whose product never exceeds
        ``MAX_THREADS_PER_GROUP``. Returns ``(32, 32, 1)`` when the dimensions
        cannot be converted to integers.
    """
    # Maximum threadgroup size for most Metal GPUs. A per-kernel limit could
    # be queried through the Metal API if this ever needs to be dynamic.
    MAX_THREADS_PER_GROUP = 1024
    MAX_THREADS_PER_DIM = 1024
    # Minimum side length, kept for efficiency on small workloads.
    MIN_THREADS_PER_DIM = 8

    default_threadgroup = (32, 32, 1)

    try:
        # Don't create threadgroups larger than the work dimensions; clamp to
        # >= 1 so degenerate (zero/negative) sizes still yield a power of two.
        max_width = max(1, min(MAX_THREADS_PER_DIM, int(out_w)))
        max_height = max(1, min(MAX_THREADS_PER_DIM, int(out_h)))
    except (TypeError, ValueError, OverflowError):
        # Best-effort helper: fall back to safe defaults for unusable input.
        return default_threadgroup

    # Largest power of two <= each (positive) dimension.
    width = 1 << (max_width.bit_length() - 1)
    height = 1 << (max_height.bit_length() - 1)

    # Apply the minimum BEFORE enforcing the total-thread cap, so the floor can
    # never push the product over the limit (e.g. 1024 x 1 must not turn into
    # 1024 x 8 = 8192 threads, which Metal rejects).
    width = max(MIN_THREADS_PER_DIM, width)
    height = max(MIN_THREADS_PER_DIM, height)

    # Halve the larger side until the group fits. Since 8 * 8 <= 1024, the side
    # being halved is always > 32 here, so neither side drops below 8.
    while width * height > MAX_THREADS_PER_GROUP:
        if width >= height:
            width //= 2
        else:
            height //= 2

    return (width, height, 1)
|