fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/gemma3n/vision.py
@@ -0,0 +1,994 @@
+from collections.abc import Sequence
+from math import sqrt
+from typing import Dict, List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_vlm.models.gemma3n.config import (
+    EdgeResidualConfig,
+    MultiQueryAttentionBlockConfig,
+    UniversalInvertedResidualConfig,
+    VisionConfig,
+)
+
+from ..kernels import bicubic_interpolate, nearest_interpolate
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L24
+class MobileNetV5MultiScaleFusionAdapter(nn.Module):
+    """Multi-layer fusion token adapter.
+
+    Attributes:
+      out_filters: The number of output filters.
+      output_resolution: The output resolution.
+      activation: The activation function.
+      expansion_ratio: The expansion ratio.
+      upsampling_interpolation: The upsampling interpolation.
+      use_layer_scale: Whether to use layer scale.
+      layer_scale_init_value: The initial value of the layer scale.
+      skip_projection: Whether to skip the projection.
+      name: The name of the module.
+      upsize: The upsampling fn.
+      downsize: The downsampling fn.
+    """
+
+    def __init__(
+        self,
+        in_chs: List[int],
+        out_chs: int,
+        output_resolution: int,
+        expansion_ratio: float = 2.0,
+        interpolation_mode: str = "nearest",
+        use_layer_scale: bool = False,
+        layer_scale_init_value: float = 1e-5,
+        noskip: bool = True,
+    ):
+        super().__init__()
+        self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs
+        self.out_channels = out_chs
+        self.output_resolution = to_2tuple(output_resolution)
+        self.expansion_ratio = expansion_ratio
+        self.interpolation_mode = interpolation_mode
+        self.use_layer_scale = use_layer_scale
+        self.layer_scale_init_value = layer_scale_init_value
+        self.noskip = noskip
+
+        norm_layer = RMSNormAct2d
+        self.ffn = UniversalInvertedResidual(
+            in_chs=self.in_channels,
+            out_chs=self.out_channels,
+            dw_kernel_size_mid=0,
+            exp_ratio=self.expansion_ratio,
+            norm_layer=norm_layer,
+            noskip=self.noskip,
+            layer_scale_init_value=(
+                self.layer_scale_init_value if self.use_layer_scale else None
+            ),
+        )
+
+        self.norm = norm_layer(self.out_channels, eps=1e-6, apply_act=False)
+
+    def __call__(self, inputs: list[mx.array]) -> mx.array:
+        inputs = [i.transpose(0, 3, 1, 2) for i in inputs]
+        high_resolution = inputs[0].shape[
+            -2:
+        ]  # Assuming the first input is the highest resolution.
+        resized_inputs = []
+
+        for _, img in enumerate(inputs):
+            if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]):
+                img = nearest_interpolate(img, size=high_resolution)
+
+            resized_inputs.append(img)
+
+        channel_cat_imgs = mx.concatenate(
+            resized_inputs, axis=1
+        )  # Cat on channel dim, must equal self.in_channels
+        img = self.ffn(channel_cat_imgs.swapaxes(1, 3)).swapaxes(1, 3)
+
+        if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
+            if (
+                high_resolution[0] % self.output_resolution[0] != 0
+                or high_resolution[1] % self.output_resolution[1] != 0
+            ):
+                img = bicubic_interpolate(img, self.output_resolution)
+            else:
+                h_strides = high_resolution[0] // self.output_resolution[0]
+                w_strides = high_resolution[1] // self.output_resolution[1]
+
+                img = nn.AvgPool2d(
+                    kernel_size=(h_strides, w_strides),
+                    stride=(h_strides, w_strides),
+                )(img.swapaxes(1, 3))
+
+        img = self.norm(img) if self.noskip else img
+
+        return img
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/layer_scale.py#L22
+class LayerScale2d(nn.Module):
+    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = init_values * mx.ones((dim,))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+def rms_norm2d(
+    x: mx.array,
+    normalized_shape: List[int],
+    weight: Optional[mx.array] = None,
+    eps: float = 1e-5,
+):
+    assert len(normalized_shape) == 1
+    dtype = x.dtype
+    v = mx.power(x, 2)
+    v = mx.mean(v, axis=1, keepdims=True)
+    x = x * mx.rsqrt(v + eps)
+    if weight is not None:
+        x = x.astype(dtype) * weight.reshape(1, -1, 1, 1)
+    return x
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/norm_act.py#L504
+class RMSNormAct2d(nn.RMSNorm):
+    def __init__(
+        self,
+        num_channels,
+        eps=1e-6,
+        apply_act: bool = True,
+    ):
+        super().__init__(dims=num_channels, eps=eps)
+        self.normalized_shape = [num_channels]
+        self.drop = nn.Identity()
+        self.act = nn.GELU() if apply_act else nn.Identity()
+
+    def __call__(self, x: mx.array) -> mx.array:
+
+        x = x.transpose(0, 3, 1, 2)  # Convert from NHWC to NCHW
+        x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+        x = self.drop(x)
+        x = self.act(x)
+        x = x.transpose(0, 2, 3, 1)  # Convert back to NHWC
+        return x
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L310
+class UniversalInvertedResidual(nn.Module):
+    def __init__(
+        self,
+        in_chs: int,
+        out_chs: int,
+        dw_kernel_size_start: int = 0,
+        dw_kernel_size_mid: int = 3,
+        dw_kernel_size_end: int = 0,
+        stride: int = 1,
+        dilation: int = 1,
+        group_size: int = 1,
+        pad_type: str = "",
+        noskip: bool = False,
+        exp_ratio: float = 1.0,
+        norm_layer=RMSNormAct2d,
+        conv_kwargs: Optional[Dict] = None,
+        drop_path_rate: float = 0.0,
+        layer_scale_init_value: Optional[float] = 1e-5,
+    ):
+        super().__init__()
+        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
+        if stride > 1:
+            assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
+
+        if dw_kernel_size_start:
+            dw_start_stride = stride if not dw_kernel_size_mid else 1
+            dw_start_groups = num_groups(group_size, in_chs)
+            self.dw_start = ConvNormAct(
+                nn.Conv2d,
+                in_chs,
+                in_chs,
+                kernel_size=dw_kernel_size_start,
+                stride=dw_start_stride,
+                padding=(dw_kernel_size_start - 1) // 2,
+                dilation=dilation,
+                groups=dw_start_groups,
+                bias=False,
+                apply_act=False,
+                eps=1e-05,
+            )
+        else:
+            self.dw_start = nn.Identity()
+
+        mid_chs = make_divisible(in_chs * exp_ratio)
+        self.pw_exp = ConvNormAct(
+            nn.Conv2d,
+            in_chs,
+            mid_chs,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            bias=False,
+            eps=1e-05,
+        )
+
+        if dw_kernel_size_mid:
+            dw_mid_groups = num_groups(group_size, mid_chs)
+            self.dw_mid = ConvNormAct(
+                Conv2dSame,
+                mid_chs,
+                mid_chs,
+                kernel_size=dw_kernel_size_mid,
+                stride=stride,
+                padding=0,
+                dilation=dilation,
+                groups=dw_mid_groups,
+                bias=False,
+                eps=1e-05,
+            )
+        else:
+            self.dw_mid = nn.Identity()
+
+        self.pw_proj = ConvNormAct(
+            nn.Conv2d,
+            mid_chs,
+            out_chs,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            groups=1,
+            bias=False,
+            apply_act=False,
+            eps=1e-05,
+        )
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        shortcut = x
+        x = self.dw_start(x)
+        x = self.pw_exp(x)
+        x = self.dw_mid(x)
+        x = self.pw_proj(x)
+        x = self.layer_scale(x)
+        if self.has_skip:
+            x = x + shortcut
+        return x
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/conv_bn_act.py#L15
+class ConvNormAct(nn.Module):
+    def __init__(
+        self,
+        conv_cls,
+        in_chs: int,
+        out_chs: int,
+        kernel_size: int = 3,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = False,
+        apply_act: bool = True,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.out_chs = out_chs
+        self.conv = conv_cls(
+            in_chs,
+            out_chs,
+            kernel_size,
+            stride,
+            padding,
+            (dilation, dilation),
+            groups,
+            bias,
+        )
+        self.bn = RMSNormAct2d(out_chs, eps=eps, apply_act=apply_act)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        c = self.conv(x)
+        r = self.bn(c)
+        return r
+
+
+def pad_same(
+    x,
+    kernel_size: List[int],
+    stride: List[int],
+    dilation: List[int] = (1, 1),
+    value: float = 0,
+):
+    """
+    Input should be in MLX format
+    """
+    ih, iw = x.shape[1:3]
+    pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
+    pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
+
+    # MLX pad format: [(low, high), (low, high), ...] for each axis
+    # Padding order is reversed compared to PyTorch F.pad
+    pad_widths = [
+        (0, 0),  # No padding for batch dimension
+        (pad_h // 2, pad_h - pad_h // 2),  # Height padding
+        (pad_w // 2, pad_w - pad_w // 2),  # Width padding
+        (0, 0),  # No padding for channel dimension
+    ]
+
+    x = mx.pad(x, pad_widths, constant_values=value)
+    return x
+
+
+def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+    dynamic = False
+    if isinstance(padding, str):
+        # for any string padding, the padding will be calculated for you, one of three ways
+        padding = padding.lower()
+        if padding == "same":
+            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+            if is_static_pad(kernel_size, **kwargs):
+                # static case, no extra overhead
+                padding = get_padding(kernel_size, **kwargs)
+            else:
+                # dynamic 'SAME' padding, has runtime/GPU memory overhead
+                padding = 0
+                dynamic = True
+        elif padding == "valid":
+            # 'VALID' padding, same as padding=0
+            padding = 0
+        else:
+            # Default to PyTorch style 'same'-ish symmetric padding
+            padding = get_padding(kernel_size, **kwargs)
+    return padding, dynamic
+
+
+def get_same_padding(
+    input_size: int, kernel_size: int, stride: int, dilation: int = 1
+) -> int:
+    """Calculate padding needed for 'same' output size."""
+    effective_kernel_size = dilation * (kernel_size - 1) + 1
+    output_size = (input_size + stride - 1) // stride
+    total_padding = max(
+        0, (output_size - 1) * stride + effective_kernel_size - input_size
+    )
+    return total_padding
+
+
+def get_padding(kernel_size, stride=1, dilation=1, **_):
+    """Get symmetric padding for given kernel size."""
+    if isinstance(kernel_size, int):
+        kernel_size = [kernel_size, kernel_size]
+    if isinstance(stride, int):
+        stride = [stride, stride]
+    if isinstance(dilation, int):
+        dilation = [dilation, dilation]
+
+    padding = []
+    for k, d in zip(kernel_size, dilation):
+        effective_k = d * (k - 1) + 1
+        pad_total = effective_k - 1
+        padding.append(pad_total // 2)
+    return tuple(padding)
+
+
+def is_static_pad(kernel_size, stride=1, dilation=1, **_):
+    """Check if padding can be calculated statically."""
+    if isinstance(kernel_size, int):
+        kernel_size = [kernel_size, kernel_size]
+    if isinstance(stride, int):
+        stride = [stride, stride]
+    if isinstance(dilation, int):
+        dilation = [dilation, dilation]
+
+    # Static padding is possible when stride is 1 for all dimensions
+    return all(s == 1 for s in stride)
+
+
+class Conv2dSame(nn.Conv2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.kernel_size = self.weight.shape[1:3]
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+        y = mx.conv2d(
+            x, self.weight, self.stride, self.padding, self.dilation, self.groups
+        )
+        if "bias" in self:
+            y = y + self.bias
+        return y
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L629
+class EdgeResidual(nn.Module):
+    def __init__(
+        self,
+        in_chs: int,
+        out_chs: int,
+        exp_kernel_size: int = 3,
+        stride: int = 1,
+        dilation: int = 1,
+        group_size: int = 0,
+        pad_type: str = "",
+        force_in_chs: int = 0,
+        noskip: bool = False,
+        expand_ratio: float = 1.0,
+        pw_kernel_size: int = 1,
+        norm_layer=RMSNormAct2d,
+    ):
+        super().__init__()
+
+        if force_in_chs > 0:
+            mid_chs = make_divisible(force_in_chs * expand_ratio)
+        else:
+            mid_chs = make_divisible(in_chs * expand_ratio)
+
+        groups = num_groups(group_size, mid_chs)
+
+        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
+
+        self.conv_exp = Conv2dSame(
+            in_chs,
+            mid_chs,
+            kernel_size=exp_kernel_size,
+            stride=stride,
+            padding=0,
+            dilation=(dilation, dilation),
+            groups=groups,
+            bias=False,
+        )
+
+        self.bn1 = norm_layer(mid_chs, eps=1e-05) if norm_layer else nn.Identity()
+
+        # Point-wise linear projection
+        padding_pwl = (pw_kernel_size - 1) // 2
+        self.conv_pwl = nn.Conv2d(
+            mid_chs,
+            out_chs,
+            kernel_size=pw_kernel_size,
+            padding=padding_pwl,
+            bias=False,
+        )
+
+        self.bn2 = (
+            norm_layer(out_chs, eps=1e-05, apply_act=False)
+            if norm_layer
+            else nn.Identity()
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        shortcut = x
+        x = self.conv_exp(x)
+        x = self.bn1(x)
+        x = self.conv_pwl(x)
+        x = self.bn2(x)
+        if self.has_skip:
+            x = x + shortcut
+        return x
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L449
+class MobileAttention(nn.Module):
+    def __init__(
+        self,
+        in_chs: int,
+        out_chs: int,
+        stride: int = 1,
+        dw_kernel_size: int = 3,
+        dilation: int = 1,
+        group_size: int = 1,
+        pad_type: str = "",
+        num_heads: int = 8,
+        key_dim: int = 64,
+        value_dim: int = 64,
+        use_multi_query: bool = True,
+        query_strides: Tuple[int, int] = (1, 1),
+        kv_stride: int = 1,
+        cpe_dw_kernel_size: int = 3,
+        noskip: bool = False,
+        act_layer=nn.GELU,
+        aa_layer=None,
+        drop_path_rate: float = 0.0,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        layer_scale_init_value: Optional[float] = 1e-5,
+        use_bias: bool = False,
+    ):
+        super().__init__()
+        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
+        self.query_strides = to_2tuple(query_strides)
+        self.kv_stride = kv_stride
+        self.has_query_stride = any([s > 1 for s in self.query_strides])
+
+        # Normalization layer
+        self.norm = RMSNormAct2d(
+            in_chs,
+            eps=1e-05,
+            apply_act=False,
+        )
+        # Determine number of heads if not provided
+        if num_heads is None:
+            assert in_chs % key_dim == 0
+            num_heads = in_chs // key_dim
+
+        # Attention layer
+        if use_multi_query:
+            self.attn = MultiQueryAttention2d(
+                in_chs,
+                dim_out=out_chs,
+                num_heads=num_heads,
+                key_dim=key_dim,
+                value_dim=value_dim,
+                query_strides=query_strides,
+                kv_stride=kv_stride,
+                dilation=dilation,
+                padding=pad_type,
+                dw_kernel_size=dw_kernel_size,
+                attn_drop=attn_drop,
+                proj_drop=proj_drop,
+            )
+        else:
+            raise NotImplementedError("attention not implemented")
+
+        # Layer scaling
+        if layer_scale_init_value is not None:
+            self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+        else:
+            self.layer_scale = nn.Identity()
+
+        # Drop path for residual connection
+        self.drop_path = nn.Identity()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        shortcut = x
+        x = self.norm(x)
+        x = self.attn(x)
+        x = self.layer_scale(x)
+
+        # Apply skip connection if available
+        if self.has_skip:
+            x = self.drop_path(x) + shortcut
+        return x
+
+
+def create_conv2d(
+    in_channels,
+    out_channels,
+    kernel_size,
+    stride=1,
+    dilation=1,
+    depthwise=False,
+    bias=False,
+    **kwargs,
+):
+    """Helper function to create a 2D convolution with common parameters"""
+    if depthwise:
+        # Depthwise convolution
+        return nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2 * dilation,
+            dilation=dilation,
+            groups=in_channels,
+            bias=bias,
+        )
+    else:
+        # Regular convolution
+        return nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=(kernel_size - 1) // 2 * dilation,
+            dilation=dilation,
+            bias=bias,
+        )
+
+
+def to_2tuple(x):
+    """Convert input to 2-tuple"""
+    if isinstance(x, tuple):
+        return x
+    return (x, x)
+
+
+class NamedSequential(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._order = []
+
+    def add_module(self, name, module):
+        setattr(self, name, module)
+        self._order.append(name)
+
+    def __call__(self, x):
+        for name in self._order:
+            x = getattr(self, name)(x)
+        return x
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/attention2d.py#L82
+class MultiQueryAttention2d(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        num_heads: int = 8,
+        key_dim: int = 64,
+        value_dim: int = 64,
+        query_strides: Tuple[int, int] = (1, 1),
+        kv_stride: int = 1,
+        dilation: int = 1,
+        padding: str = "",
+        dw_kernel_size: int = 3,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ):
+        super().__init__()
+        dim_out = dim_out or dim
+        self.num_heads = num_heads
+        self.query_strides = to_2tuple(query_strides)
+        self.kv_stride = kv_stride
+        self.fused_attn = True
+        self.key_dim = key_dim
+        self.value_dim = value_dim
+        head_dim = key_dim
+        self.scale = head_dim**-0.5
+
+        self.query = NamedSequential()
+        self.query.add_module(
+            "proj",
+            create_conv2d(
+                dim,
+                self.num_heads * self.key_dim,
+                kernel_size=1,
+            ),
+        )
+        self.key = NamedSequential()
+        if kv_stride > 1:
+            self.key.add_module(
+                "down_conv",
+                create_conv2d(
+                    dim,
+                    dim,
+                    kernel_size=dw_kernel_size,
+                    stride=kv_stride,
+                    dilation=dilation,
+                    padding=padding,
+                    depthwise=True,
+                ),
+            )
+        self.key.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
+        self.key.add_module(
+            "proj", create_conv2d(dim, key_dim, kernel_size=1, bias=False)
+        )
+
+        self.value = NamedSequential()
+        if kv_stride > 1:
+            self.value.add_module(
+                "down_conv",
+                create_conv2d(
+                    dim,
+                    dim,
+                    kernel_size=dw_kernel_size,
+                    stride=kv_stride,
+                    dilation=dilation,
+                    padding=padding,
+                    depthwise=True,
+                ),
+            )
+        self.value.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
+        self.value.add_module(
+            "proj", create_conv2d(dim, value_dim, kernel_size=1, bias=False)
+        )
+
+        # Attention dropout
+        self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
+
+        # Output projection
+        self.output = NamedSequential()
+        self.output.add_module(
+            "proj",
+            create_conv2d(
+                value_dim * num_heads,
+                dim_out,
+                kernel_size=1,
+                stride=1,
+                bias=False,
+            ),
+        )
+        self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
+
+    def _reshape_input(self, t: mx.array):
+        """
+        Input shape MLX: [B, H, W, C]
+        Input shape PyTorch: [B, C, H, W]
+
+        PyTorch Reshape: [B, C, H, W] -> [B, C, -1] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
+        MLX Reshape: [B, H, W, C] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
+        """
+        s = t.shape
+        t = t.reshape(s[0], -1, s[3])[:, None, :, :]
+
+        return t
+
+    def _reshape_projected_query(self, t: mx.array, num_heads: int, key_dim: int):
+        """
+        Input shape MLX: [B, H, W, C] where C = num_heads * key_dim
+        """
+        B, H, W, C = t.shape
+        # t = t.reshape(B, H, W, num_heads, key_dim)
+        t = t.reshape(B, H * W, num_heads, key_dim)
+        return t.transpose(0, 2, 1, 3)
+
+    def _reshape_output(self, t: mx.array, num_heads: int, h_px: int, w_px: int):
+        """
+        Input shape: [B, NH, L, D] where L = h_px * w_px
+        Output shape MLX: [B, H, W, C] where C = NH * D
+        """
+        B, NH, L, D = t.shape
+        # First transpose to [B, L, NH, D]
+        t = t.transpose(0, 2, 1, 3)
+        # Then reshape to [B, H, W, NH*D]
+        t = t.reshape(B, h_px, w_px, NH * D)
+        return t
+
+    def __call__(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
+        B, H, W, C = x.shape
+        q = self.query(x)
+        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
+
+        k = self.key(x)
+        k = self._reshape_input(k)
+
+        v = self.value(x)
+        v = self._reshape_input(v)
+
+        if self.fused_attn:
+            o = mx.fast.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                scale=1.0 / sqrt(q.shape[-1]),
+            )
+        else:
+            raise NotImplementedError("unfused attention not implemented")
+
+        o = self._reshape_output(
+            o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]
+        )
+        x = self.output(o)
+        return x
+
+
+def num_groups(group_size: Optional[int], channels: int) -> int:
+    if not group_size:  # 0 or None
+        return 1  # normal conv with 1 group
+    else:
+        # NOTE group_size == 1 -> depthwise conv
+        assert channels % group_size == 0
+        return channels // group_size
+
+
+def make_divisible(v, divisor: int = 8, min_value=None, round_limit: float = 0.9):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
+    if new_v < round_limit * v:
+        new_v += divisor
+    return new_v
+
+
+def _er(kernel_size, filters, strides=1, expand_ratio=4.0, is_multiscale=False):
+    return EdgeResidualConfig(
+        kernel_size=kernel_size,
+        filters=filters,
+        strides=strides,
+        expand_ratio=expand_ratio,
+        is_multiscale=is_multiscale,
+    )
+
+
+def _uir(
+    start_dw_kernel_size,
+    mid_dw_kernel_size,
+    filters,
+    strides=1,
+    expand_ratio=4.0,
+    is_multiscale=False,
+):
+    return UniversalInvertedResidualConfig(
+        start_dw_kernel_size=start_dw_kernel_size,
+        mid_dw_kernel_size=mid_dw_kernel_size,
+        filters=filters,
+        strides=strides,
+        expand_ratio=expand_ratio,
+        is_multiscale=is_multiscale,
+    )
+
+
+def _mmqa(
+    num_heads,
+    kv_dim,
+    kv_strides,
+    mmqa_avg_pool_kv=False,
+    is_multiscale=False,
+):
+    conf = MultiQueryAttentionBlockConfig(
+        num_heads=num_heads,
+        kv_dim=kv_dim,
+        kv_strides=kv_strides,
+        mmqa_avg_pool_kv=mmqa_avg_pool_kv,
+        is_multiscale=is_multiscale,
+    )
+    return conf
+
+
+# https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L596
+def gemma3n_mobilenet_def():
+    return [
+        # Stage 1: Edge Residuals
+        [_er(3, 128, 2)] + [_er(3, 128, 1)] * 2,
+        # Stage 2: Universal Inverted Residuals
+        [_uir(3, 5, 256, 2, 6.0)] + [_uir(k, 0, 256) for k in [5, 3, 5, 3]],
+        # Stage 3: Universal Inverted Residuals with Multi-Query Attention
+        [_uir(5, 5, 640, 2, 6.0)]
+        + [_uir(5, 0, 640)] * 7
+        + [_uir(0, 0, 640, 1, 1.0)]
+        + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0)] * 13
+        + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0, is_multiscale=True)],
+        # Stage 4: Universal Inverted Residuals with Multi-Query Attention
+        [_uir(5, 5, 1280, 2, 6.0)]
+        + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0)] * 18
+        + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0, is_multiscale=True)],
+    ]
+
+
+class VisionTower(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.conv_stem = ConvNormAct(
+            Conv2dSame,
+            in_chs=3,
+            out_chs=64,
+            kernel_size=3,
+            stride=2,
+            padding=0,
+            eps=1e-05,
+            bias=True,
+        )
+        msfa_indices = (3, 4)
+        msfa_output_resolution = (16, 16)
+
+        (num_features, self.blocks) = self.build()
+        self.num_features = self.head_hidden_size = (
+            num_features  # output of msfa is output of forward_features()
+        )
+        self.msfa_indices = msfa_indices
+        self.msfa_output_resolution = msfa_output_resolution
+
+        self.msfa = MobileNetV5MultiScaleFusionAdapter(
+            in_chs=[1920],
+            out_chs=2048,
+            output_resolution=self.msfa_output_resolution,
+        )
+
+    def build(self):
+        blocks = []
+        in_chs = self.conv_stem.out_chs
+        for stage, block_config in enumerate(gemma3n_mobilenet_def()):
+            block_group = []
+            for config in block_config:
+                match config:
+                    case EdgeResidualConfig(
+                        kernel_size, filters, strides, expand_ratio, is_multiscale
+                    ):
+                        x = EdgeResidual(
+                            exp_kernel_size=kernel_size,
+                            in_chs=in_chs,
+                            out_chs=filters,
+                            stride=strides,
+                            expand_ratio=expand_ratio,
+                        )
+                        in_chs = filters  # in_chs of next is out_chs of prev
+                        block_group.append(x)
+                    case UniversalInvertedResidualConfig(
+                        start_dw_kernel_size,
+                        mid_dw_kernel_size,
+                        filters,
+                        strides,
+                        expand_ratio,
+                        is_multiscale,
+                    ):
+                        x = UniversalInvertedResidual(
+                            in_chs=in_chs,
+                            out_chs=filters,
+                            dw_kernel_size_start=start_dw_kernel_size,
+                            dw_kernel_size_mid=mid_dw_kernel_size,
+                            stride=strides,
+                            exp_ratio=expand_ratio,
+                        )
+                        in_chs = filters
+                        block_group.append(x)
+                    case MultiQueryAttentionBlockConfig(
+                        num_heads,
+                        kv_dim,
+                        kv_strides,
+                        mmqa_avg_pool_kv,
+                        is_multiscale,
+                    ):
+                        x = MobileAttention(
+                            in_chs=in_chs,
+                            out_chs=in_chs,
+                            stride=1,
+                            num_heads=num_heads,
+                            key_dim=kv_dim,
+                            value_dim=kv_dim,
+                            kv_stride=kv_strides,
+                            act_layer=None,
+                        )
+                        block_group.append(x)
+                    case _:
+                        continue
+            blocks.append(block_group)
+        return (in_chs, blocks)
+
+    def __call__(
+        self, x: mx.array, output_hidden_states: Optional[bool] = None
+    ) -> mx.array:
+        feat_idx = 0
+        x = x.transpose(0, 2, 3, 1)  # Convert from NCHW to NHWC
+        x = self.conv_stem(x)
+        intermediates = []
+
+        if feat_idx in self.msfa_indices:
+            intermediates.append(x)
+
+        # MBV5 is constructed of 4 stages, each stage is a group of blocks.
+        for block_group in self.blocks:
+            feat_idx += 1
+            for block in block_group:
+                x = block(x)
+
+            if feat_idx in self.msfa_indices:
+                intermediates.append(x)
+
+        x = self.msfa(intermediates)
+        return x
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type not in ["gemma3", "gemma3_vision", "gemma3n_vision"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.timm_model = VisionTower(config)
+
+    def __call__(
+        self, x: mx.array, output_hidden_states: Optional[bool] = None
+    ) -> mx.array:
+        return self.timm_model(x, output_hidden_states)
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        skip_transpose = False
+        _, H, _, C = weights["vision_tower.timm_model.blocks.0.0.conv_exp.weight"].shape
+        if C > H:
+            skip_transpose = True
+
+        for k, v in weights.items():
+            # PyTorch conv2d weight: [out_channels, in_channels, kH, kW]
+            # MLX conv2d weight: [out_channels, kH, KW, in_channels]
+            if ("conv" in k and "weight" in k) or ("attn" and "proj.weight") in k:
+                if len(v.shape) == 4 and not skip_transpose:
+                    v = v.transpose(0, 2, 3, 1)
+            sanitized_weights[k] = v
+
+        return sanitized_weights
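
Note on the sanitize method above: it only permutes convolution weights from PyTorch's [out_channels, in_channels, kH, kW] layout to MLX's [out_channels, kH, kW, in_channels] layout; the values themselves are unchanged. A minimal sketch of the same permutation, using an illustrative shape that does not come from this package:

    import mlx.core as mx

    # Hypothetical 5x5 filter bank: 64 output channels, 3 input channels.
    torch_style = mx.zeros((64, 3, 5, 5))           # PyTorch layout [O, I, kH, kW]
    mlx_style = torch_style.transpose(0, 2, 3, 1)   # MLX layout [O, kH, kW, I]
    assert mlx_style.shape == (64, 5, 5, 3)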
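
Note on the tower geometry: gemma3n_mobilenet_def builds four stages (128, 256, 640, and 1280 channels), the stem and the first block of each stage downsample by 2, and msfa_indices = (3, 4) routes the stage-3 and stage-4 feature maps into the fusion adapter (640 + 1280 = 1920 concatenated channels, projected to 2048 at a 16x16 grid). A rough sanity check of that bookkeeping, assuming a 768x768 input; the input size is an assumption, not something this file states:

    # Stem plus one stride-2 block per stage halves the resolution five times.
    size = 768
    for stride in (2, 2, 2, 2, 2):
        size //= stride
    print(size)        # 24 -> stage-4 map is 24x24x1280; stage-3 is 48x48x640
    print(640 + 1280)  # 1920, matching the adapter's in_chs after concatenation
    # The adapter upsamples 24x24 to 48x48, concatenates, projects to 2048 channels,
    # and average-pools 48x48 down to the 16x16 output grid.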