nexaai 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of nexaai might be problematic.

Files changed (196)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  5. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
  6. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  7. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  8. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  10. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  11. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  12. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  13. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  14. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  15. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  16. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  17. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  18. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  19. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  21. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  22. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  23. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  24. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  33. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  34. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  35. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  36. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  37. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  39. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  40. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  41. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  43. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  44. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  45. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  46. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  47. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  48. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  49. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  54. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  55. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  56. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  57. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  58. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  59. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  61. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  62. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  64. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  195. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  196. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
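To reproduce the file list above locally, the sketch below compares the archive members of the two wheels using only Python's standard zipfile module. It is a minimal sketch, not the registry's own tooling; the wheel paths are assumptions and should point at wherever the two .whl files were downloaded.

    # compare_wheels.py - list files added/removed between two wheel archives
    import zipfile

    OLD = "nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl"  # assumed local path
    NEW = "nexaai-1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl"  # assumed local path

    old_files = set(zipfile.ZipFile(OLD).namelist())
    new_files = set(zipfile.ZipFile(NEW).namelist())

    for path in sorted(old_files - new_files):
        print("removed:", path)
    for path in sorted(new_files - old_files):
        print("added:  ", path)

The removal hunk below appears to belong to nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py, the only file in the list above with exactly 1,022 deleted lines.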
@@ -1,1022 +0,0 @@
- import inspect
- from collections.abc import Sequence
- from dataclasses import dataclass
- from math import sqrt
- from typing import Dict, List, Optional, Tuple, Type
-
- import mlx.core as mx
- import mlx.nn as nn
-
- from .config import VisionConfig
-
- from ..base import check_array_shape
- from ..kernels import bicubic_interpolate, nearest_interpolate
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L24
- class MobileNetV5MultiScaleFusionAdapter(nn.Module):
-     """Multi-layer fusion token adapter.
-     Attributes:
-         out_filters: The number of output filters.
-         output_resolution: The output resolution.
-         activation: The activation function.
-         expansion_ratio: The expansion ratio.
-         upsampling_interpolation: The upsampling interpolation.
-         use_layer_scale: Whether to use layer scale.
-         layer_scale_init_value: The initial value of the layer scale.
-         skip_projection: Whether to skip the projection.
-         name: The name of the module.
-         upsize: The upsampling fn.
-         downsize: The downsampling fn.
-     """
-
-     def __init__(
-         self,
-         in_chs: List[int],
-         out_chs: int,
-         output_resolution: int,
-         expansion_ratio: float = 2.0,
-         interpolation_mode: str = "nearest",
-         use_layer_scale: bool = False,
-         layer_scale_init_value: float = 1e-5,
-         noskip: bool = True,
-     ):
-         super().__init__()
-         self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs
-         self.out_channels = out_chs
-         self.output_resolution = to_2tuple(output_resolution)
-         self.expansion_ratio = expansion_ratio
-         self.interpolation_mode = interpolation_mode
-         self.use_layer_scale = use_layer_scale
-         self.layer_scale_init_value = layer_scale_init_value
-         self.noskip = noskip
-
-         norm_layer = RMSNormAct2d
-         self.ffn = UniversalInvertedResidual(
-             in_chs=self.in_channels,
-             out_chs=self.out_channels,
-             dw_kernel_size_mid=0,
-             exp_ratio=self.expansion_ratio,
-             norm_layer=norm_layer,
-             noskip=self.noskip,
-             layer_scale_init_value=(
-                 self.layer_scale_init_value if self.use_layer_scale else None
-             ),
-         )
-
-         self.norm = norm_layer(self.out_channels, eps=1e-6, apply_act=False)
-
-     def __call__(self, inputs: list[mx.array]) -> mx.array:
-         inputs = [i.transpose(0, 3, 1, 2) for i in inputs]
-         high_resolution = inputs[0].shape[
-             -2:
-         ]  # Assuming the first input is the highest resolution.
-         resized_inputs = []
-
-         for _, img in enumerate(inputs):
-             if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]):
-                 img = nearest_interpolate(img, size=high_resolution)
-
-             resized_inputs.append(img)
-
-         channel_cat_imgs = mx.concatenate(
-             resized_inputs, axis=1
-         )  # Cat on channel dim, must equal self.in_channels
-         img = self.ffn(channel_cat_imgs.swapaxes(1, 3)).swapaxes(1, 3)
-
-         if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
-             if (
-                 high_resolution[0] % self.output_resolution[0] != 0
-                 or high_resolution[1] % self.output_resolution[1] != 0
-             ):
-                 img = bicubic_interpolate(img, self.output_resolution)
-             else:
-                 h_strides = high_resolution[0] // self.output_resolution[0]
-                 w_strides = high_resolution[1] // self.output_resolution[1]
-
-                 img = nn.AvgPool2d(
-                     kernel_size=(h_strides, w_strides),
-                     stride=(h_strides, w_strides),
-                 )(img.swapaxes(1, 3))
-
-         img = self.norm(img) if self.noskip else img
-
-         return img
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/layer_scale.py#L22
- class LayerScale2d(nn.Module):
-     def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
-         super().__init__()
-         self.inplace = inplace
-         self.gamma = init_values * mx.ones((dim,))
-
-     def __call__(self, x: mx.array) -> mx.array:
-         return x.mul_(self.gamma) if self.inplace else x * self.gamma
-
-
- def rms_norm2d(
-     x: mx.array,
-     normalized_shape: List[int],
-     weight: Optional[mx.array] = None,
-     eps: float = 1e-5,
- ):
-     assert len(normalized_shape) == 1
-     dtype = x.dtype
-     v = mx.power(x, 2)
-     v = mx.mean(v, axis=1, keepdims=True)
-     x = x * mx.rsqrt(v + eps)
-     if weight is not None:
-         x = x.astype(dtype) * weight.reshape(1, -1, 1, 1)
-     return x
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/norm_act.py#L504
- class RMSNormAct2d(nn.RMSNorm):
-     def __init__(
-         self,
-         num_channels,
-         eps=1e-6,
-         apply_act: bool = True,
-     ):
-         super().__init__(dims=num_channels, eps=eps)
-         self.normalized_shape = [num_channels]
-         self.drop = nn.Identity()
-         self.act = nn.GELU() if apply_act else nn.Identity()
-
-     def __call__(self, x: mx.array) -> mx.array:
-
-         x = x.transpose(0, 3, 1, 2)  # Convert from NHWC to NCHW
-         x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
-         x = self.drop(x)
-         x = self.act(x)
-         x = x.transpose(0, 2, 3, 1)  # Convert back to NHWC
-         return x
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L310
- class UniversalInvertedResidual(nn.Module):
-     def __init__(
-         self,
-         in_chs: int,
-         out_chs: int,
-         dw_kernel_size_start: int = 0,
-         dw_kernel_size_mid: int = 3,
-         dw_kernel_size_end: int = 0,
-         stride: int = 1,
-         dilation: int = 1,
-         group_size: int = 1,
-         pad_type: str = "",
-         noskip: bool = False,
-         exp_ratio: float = 1.0,
-         norm_layer=RMSNormAct2d,
-         conv_kwargs: Optional[Dict] = None,
-         drop_path_rate: float = 0.0,
-         layer_scale_init_value: Optional[float] = 1e-5,
-     ):
-         super().__init__()
-         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
-         if stride > 1:
-             assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
-
-         if dw_kernel_size_start:
-             dw_start_stride = stride if not dw_kernel_size_mid else 1
-             dw_start_groups = num_groups(group_size, in_chs)
-             self.dw_start = ConvNormAct(
-                 nn.Conv2d,
-                 in_chs,
-                 in_chs,
-                 kernel_size=dw_kernel_size_start,
-                 stride=dw_start_stride,
-                 padding=(dw_kernel_size_start - 1) // 2,
-                 dilation=dilation,
-                 groups=dw_start_groups,
-                 bias=False,
-                 apply_act=False,
-                 eps=1e-05,
-             )
-         else:
-             self.dw_start = nn.Identity()
-
-         mid_chs = make_divisible(in_chs * exp_ratio)
-         self.pw_exp = ConvNormAct(
-             nn.Conv2d,
-             in_chs,
-             mid_chs,
-             kernel_size=1,
-             stride=1,
-             padding=0,
-             groups=1,
-             bias=False,
-             eps=1e-05,
-         )
-
-         if dw_kernel_size_mid:
-             dw_mid_groups = num_groups(group_size, mid_chs)
-             self.dw_mid = ConvNormAct(
-                 Conv2dSame,
-                 mid_chs,
-                 mid_chs,
-                 kernel_size=dw_kernel_size_mid,
-                 stride=stride,
-                 padding=0,
-                 dilation=dilation,
-                 groups=dw_mid_groups,
-                 bias=False,
-                 eps=1e-05,
-             )
-         else:
-             self.dw_mid = nn.Identity()
-
-         self.pw_proj = ConvNormAct(
-             nn.Conv2d,
-             mid_chs,
-             out_chs,
-             kernel_size=1,
-             stride=1,
-             padding=0,
-             groups=1,
-             bias=False,
-             apply_act=False,
-             eps=1e-05,
-         )
-         if layer_scale_init_value is not None:
-             self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
-         else:
-             self.layer_scale = nn.Identity()
-
-     def __call__(self, x: mx.array) -> mx.array:
-         shortcut = x
-         x = self.dw_start(x)
-         x = self.pw_exp(x)
-         x = self.dw_mid(x)
-         x = self.pw_proj(x)
-         x = self.layer_scale(x)
-         if self.has_skip:
-             x = x + shortcut
-         return x
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/conv_bn_act.py#L15
- class ConvNormAct(nn.Module):
-     def __init__(
-         self,
-         conv_cls,
-         in_chs: int,
-         out_chs: int,
-         kernel_size: int = 3,
-         stride: int = 1,
-         padding: int = 0,
-         dilation: int = 1,
-         groups: int = 1,
-         bias: bool = False,
-         apply_act: bool = True,
-         eps: float = 1e-6,
-     ):
-         super().__init__()
-         self.out_chs = out_chs
-         self.conv = conv_cls(
-             in_chs,
-             out_chs,
-             kernel_size,
-             stride,
-             padding,
-             (dilation, dilation),
-             groups,
-             bias,
-         )
-         self.bn = RMSNormAct2d(out_chs, eps=eps, apply_act=apply_act)
-
-     def __call__(self, x: mx.array) -> mx.array:
-         c = self.conv(x)
-         r = self.bn(c)
-         return r
-
-
- def pad_same(
-     x,
-     kernel_size: List[int],
-     stride: List[int],
-     dilation: List[int] = (1, 1),
-     value: float = 0,
- ):
-     """
-     Input should be in MLX format
-     """
-     ih, iw = x.shape[1:3]
-     pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
-     pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
-
-     # MLX pad format: [(low, high), (low, high), ...] for each axis
-     # Padding order is reversed compared to PyTorch F.pad
-     pad_widths = [
-         (0, 0),  # No padding for batch dimension
-         (pad_h // 2, pad_h - pad_h // 2),  # Height padding
-         (pad_w // 2, pad_w - pad_w // 2),  # Width padding
-         (0, 0),  # No padding for channel dimension
-     ]
-
-     x = mx.pad(x, pad_widths, constant_values=value)
-     return x
-
-
- def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
-     dynamic = False
-     if isinstance(padding, str):
-         # for any string padding, the padding will be calculated for you, one of three ways
-         padding = padding.lower()
-         if padding == "same":
-             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
-             if is_static_pad(kernel_size, **kwargs):
-                 # static case, no extra overhead
-                 padding = get_padding(kernel_size, **kwargs)
-             else:
-                 # dynamic 'SAME' padding, has runtime/GPU memory overhead
-                 padding = 0
-                 dynamic = True
-         elif padding == "valid":
-             # 'VALID' padding, same as padding=0
-             padding = 0
-         else:
-             # Default to PyTorch style 'same'-ish symmetric padding
-             padding = get_padding(kernel_size, **kwargs)
-     return padding, dynamic
-
-
- def get_same_padding(
-     input_size: int, kernel_size: int, stride: int, dilation: int = 1
- ) -> int:
-     """Calculate padding needed for 'same' output size."""
-     effective_kernel_size = dilation * (kernel_size - 1) + 1
-     output_size = (input_size + stride - 1) // stride
-     total_padding = max(
-         0, (output_size - 1) * stride + effective_kernel_size - input_size
-     )
-     return total_padding
-
-
- def get_padding(kernel_size, stride=1, dilation=1, **_):
-     """Get symmetric padding for given kernel size."""
-     if isinstance(kernel_size, int):
-         kernel_size = [kernel_size, kernel_size]
-     if isinstance(stride, int):
-         stride = [stride, stride]
-     if isinstance(dilation, int):
-         dilation = [dilation, dilation]
-
-     padding = []
-     for k, d in zip(kernel_size, dilation):
-         effective_k = d * (k - 1) + 1
-         pad_total = effective_k - 1
-         padding.append(pad_total // 2)
-     return tuple(padding)
-
-
- def is_static_pad(kernel_size, stride=1, dilation=1, **_):
-     """Check if padding can be calculated statically."""
-     if isinstance(kernel_size, int):
-         kernel_size = [kernel_size, kernel_size]
-     if isinstance(stride, int):
-         stride = [stride, stride]
-     if isinstance(dilation, int):
-         dilation = [dilation, dilation]
-
-     # Static padding is possible when stride is 1 for all dimensions
-     return all(s == 1 for s in stride)
-
-
- class Conv2dSame(nn.Conv2d):
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs)
-         self.kernel_size = self.weight.shape[1:3]
-
-     def __call__(self, x: mx.array) -> mx.array:
-         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
-         y = mx.conv2d(
-             x, self.weight, self.stride, self.padding, self.dilation, self.groups
-         )
-         if "bias" in self:
-             y = y + self.bias
-         return y
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L629
- class EdgeResidual(nn.Module):
-     def __init__(
-         self,
-         in_chs: int,
-         out_chs: int,
-         exp_kernel_size: int = 3,
-         stride: int = 1,
-         dilation: int = 1,
-         group_size: int = 0,
-         pad_type: str = "",
-         force_in_chs: int = 0,
-         noskip: bool = False,
-         expand_ratio: float = 1.0,
-         pw_kernel_size: int = 1,
-         norm_layer=RMSNormAct2d,
-     ):
-         super().__init__()
-
-         if force_in_chs > 0:
-             mid_chs = make_divisible(force_in_chs * expand_ratio)
-         else:
-             mid_chs = make_divisible(in_chs * expand_ratio)
-
-         groups = num_groups(group_size, mid_chs)
-
-         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
-
-         self.conv_exp = Conv2dSame(
-             in_chs,
-             mid_chs,
-             kernel_size=exp_kernel_size,
-             stride=stride,
-             padding=0,
-             dilation=(dilation, dilation),
-             groups=groups,
-             bias=False,
-         )
-
-         self.bn1 = norm_layer(mid_chs, eps=1e-05) if norm_layer else nn.Identity()
-
-         # Point-wise linear projection
-         padding_pwl = (pw_kernel_size - 1) // 2
-         self.conv_pwl = nn.Conv2d(
-             mid_chs,
-             out_chs,
-             kernel_size=pw_kernel_size,
-             padding=padding_pwl,
-             bias=False,
-         )
-
-         self.bn2 = (
-             norm_layer(out_chs, eps=1e-05, apply_act=False)
-             if norm_layer
-             else nn.Identity()
-         )
-
-     def __call__(self, x: mx.array) -> mx.array:
-         shortcut = x
-         x = self.conv_exp(x)
-         x = self.bn1(x)
-         x = self.conv_pwl(x)
-         x = self.bn2(x)
-         if self.has_skip:
-             x = x + shortcut
-         return x
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L449
- class MobileAttention(nn.Module):
-     def __init__(
-         self,
-         in_chs: int,
-         out_chs: int,
-         stride: int = 1,
-         dw_kernel_size: int = 3,
-         dilation: int = 1,
-         group_size: int = 1,
-         pad_type: str = "",
-         num_heads: int = 8,
-         key_dim: int = 64,
-         value_dim: int = 64,
-         use_multi_query: bool = True,
-         query_strides: Tuple[int, int] = (1, 1),
-         kv_stride: int = 1,
-         cpe_dw_kernel_size: int = 3,
-         noskip: bool = False,
-         act_layer=nn.GELU,
-         aa_layer=None,
-         drop_path_rate: float = 0.0,
-         attn_drop: float = 0.0,
-         proj_drop: float = 0.0,
-         layer_scale_init_value: Optional[float] = 1e-5,
-         use_bias: bool = False,
-     ):
-         super().__init__()
-         self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
-         self.query_strides = to_2tuple(query_strides)
-         self.kv_stride = kv_stride
-         self.has_query_stride = any([s > 1 for s in self.query_strides])
-
-         # Normalization layer
-         self.norm = RMSNormAct2d(
-             in_chs,
-             eps=1e-05,
-             apply_act=False,
-         )
-         # Determine number of heads if not provided
-         if num_heads is None:
-             assert in_chs % key_dim == 0
-             num_heads = in_chs // key_dim
-
-         # Attention layer
-         if use_multi_query:
-             self.attn = MultiQueryAttention2d(
-                 in_chs,
-                 dim_out=out_chs,
-                 num_heads=num_heads,
-                 key_dim=key_dim,
-                 value_dim=value_dim,
-                 query_strides=query_strides,
-                 kv_stride=kv_stride,
-                 dilation=dilation,
-                 padding=pad_type,
-                 dw_kernel_size=dw_kernel_size,
-                 attn_drop=attn_drop,
-                 proj_drop=proj_drop,
-             )
-         else:
-             raise NotImplementedError("attention not implemented")
-
-         # Layer scaling
-         if layer_scale_init_value is not None:
-             self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
-         else:
-             self.layer_scale = nn.Identity()
-
-         # Drop path for residual connection
-         self.drop_path = nn.Identity()
-
-     def __call__(self, x: mx.array) -> mx.array:
-         shortcut = x
-         x = self.norm(x)
-         x = self.attn(x)
-         x = self.layer_scale(x)
-
-         # Apply skip connection if available
-         if self.has_skip:
-             x = self.drop_path(x) + shortcut
-         return x
-
-
- def create_conv2d(
-     in_channels,
-     out_channels,
-     kernel_size,
-     stride=1,
-     dilation=1,
-     depthwise=False,
-     bias=False,
-     **kwargs,
- ):
-     """Helper function to create a 2D convolution with common parameters"""
-     if depthwise:
-         # Depthwise convolution
-         return nn.Conv2d(
-             in_channels,
-             out_channels,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=(kernel_size - 1) // 2 * dilation,
-             dilation=dilation,
-             groups=in_channels,
-             bias=bias,
-         )
-     else:
-         # Regular convolution
-         return nn.Conv2d(
-             in_channels,
-             out_channels,
-             kernel_size=kernel_size,
-             stride=stride,
-             padding=(kernel_size - 1) // 2 * dilation,
-             dilation=dilation,
-             bias=bias,
-         )
-
-
- def to_2tuple(x):
-     """Convert input to 2-tuple"""
-     if isinstance(x, tuple):
-         return x
-     return (x, x)
-
-
- class NamedSequential(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self._order = []
-
-     def add_module(self, name, module):
-         setattr(self, name, module)
-         self._order.append(name)
-
-     def __call__(self, x):
-         for name in self._order:
-             x = getattr(self, name)(x)
-         return x
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/attention2d.py#L82
- class MultiQueryAttention2d(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         dim_out: Optional[int] = None,
-         num_heads: int = 8,
-         key_dim: int = 64,
-         value_dim: int = 64,
-         query_strides: Tuple[int, int] = (1, 1),
-         kv_stride: int = 1,
-         dilation: int = 1,
-         padding: str = "",
-         dw_kernel_size: int = 3,
-         attn_drop: float = 0.0,
-         proj_drop: float = 0.0,
-     ):
-         super().__init__()
-         dim_out = dim_out or dim
-         self.num_heads = num_heads
-         self.query_strides = to_2tuple(query_strides)
-         self.kv_stride = kv_stride
-         self.fused_attn = True
-         self.key_dim = key_dim
-         self.value_dim = value_dim
-         head_dim = key_dim
-         self.scale = head_dim**-0.5
-
-         self.query = NamedSequential()
-         self.query.add_module(
-             "proj",
-             create_conv2d(
-                 dim,
-                 self.num_heads * self.key_dim,
-                 kernel_size=1,
-             ),
-         )
-         self.key = NamedSequential()
-         if kv_stride > 1:
-             self.key.add_module(
-                 "down_conv",
-                 create_conv2d(
-                     dim,
-                     dim,
-                     kernel_size=dw_kernel_size,
-                     stride=kv_stride,
-                     dilation=dilation,
-                     padding=padding,
-                     depthwise=True,
-                 ),
-             )
-         self.key.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
-         self.key.add_module(
-             "proj", create_conv2d(dim, key_dim, kernel_size=1, bias=False)
-         )
-
-         self.value = NamedSequential()
-         if kv_stride > 1:
-             self.value.add_module(
-                 "down_conv",
-                 create_conv2d(
-                     dim,
-                     dim,
-                     kernel_size=dw_kernel_size,
-                     stride=kv_stride,
-                     dilation=dilation,
-                     padding=padding,
-                     depthwise=True,
-                 ),
-             )
-         self.value.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
-         self.value.add_module(
-             "proj", create_conv2d(dim, value_dim, kernel_size=1, bias=False)
-         )
-
-         # Attention dropout
-         self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
-
-         # Output projection
-         self.output = NamedSequential()
-         self.output.add_module(
-             "proj",
-             create_conv2d(
-                 value_dim * num_heads,
-                 dim_out,
-                 kernel_size=1,
-                 stride=1,
-                 bias=False,
-             ),
-         )
-         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
-
-     def _reshape_input(self, t: mx.array):
-         """
-         Input shape MLX: [B, H, W, C]
-         Input shape PyTorch: [B, C, H, W]
-
-         PyTorch Reshape: [B, C, H, W] -> [B, C, -1] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
-         MLX Reshape: [B, H, W, C] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
-         """
-         s = t.shape
-         t = t.reshape(s[0], -1, s[3])[:, None, :, :]
-
-         return t
-
-     def _reshape_projected_query(self, t: mx.array, num_heads: int, key_dim: int):
-         """
-         Input shape MLX: [B, H, W, C] where C = num_heads * key_dim
-         """
-         B, H, W, C = t.shape
-         # t = t.reshape(B, H, W, num_heads, key_dim)
-         t = t.reshape(B, H * W, num_heads, key_dim)
-         return t.transpose(0, 2, 1, 3)
-
-     def _reshape_output(self, t: mx.array, num_heads: int, h_px: int, w_px: int):
-         """
-         Input shape: [B, NH, L, D] where L = h_px * w_px
-         Output shape MLX: [B, H, W, C] where C = NH * D
-         """
-         B, NH, L, D = t.shape
-         # First transpose to [B, L, NH, D]
-         t = t.transpose(0, 2, 1, 3)
-         # Then reshape to [B, H, W, NH*D]
-         t = t.reshape(B, h_px, w_px, NH * D)
-         return t
-
-     def __call__(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
-         B, H, W, C = x.shape
-         q = self.query(x)
-         q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
-
-         k = self.key(x)
-         k = self._reshape_input(k)
-
-         v = self.value(x)
-         v = self._reshape_input(v)
-
-         if self.fused_attn:
-             o = mx.fast.scaled_dot_product_attention(
-                 q,
-                 k,
-                 v,
-                 scale=1.0 / sqrt(q.shape[-1]),
-             )
-         else:
-             raise NotImplementedError("unfused attention not implemented")
-
-         o = self._reshape_output(
-             o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]
-         )
-         x = self.output(o)
-         return x
-
-
- def num_groups(group_size: Optional[int], channels: int) -> int:
-     if not group_size:  # 0 or None
-         return 1  # normal conv with 1 group
-     else:
-         # NOTE group_size == 1 -> depthwise conv
-         assert channels % group_size == 0
-         return channels // group_size
-
-
- def make_divisible(v, divisor: int = 8, min_value=None, round_limit: float = 0.9):
-     min_value = min_value or divisor
-     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-     # Make sure that round down does not go down by more than 10%.
-     if new_v < round_limit * v:
-         new_v += divisor
-     return new_v
-
-
- @dataclass(frozen=True)
- class EdgeResidualConfig:
-     kernel_size: int = 3
-     filters: int = 32
-     strides: int = 1
-     expand_ratio: float = 4.0
-     is_multiscale: bool = False
-
-
- def _er(kernel_size, filters, strides=1, expand_ratio=4.0, is_multiscale=False):
-     return EdgeResidualConfig(
-         kernel_size=kernel_size,
-         filters=filters,
-         strides=strides,
-         expand_ratio=expand_ratio,
-         is_multiscale=is_multiscale,
-     )
-
-
- @dataclass(frozen=True)
- class UniversalInvertedResidualConfig:
-     start_dw_kernel_size: int = 0  # Zero size means no conv
-     mid_dw_kernel_size: int = 0  # Zero size means no conv
-     filters: int = 32
-     strides: int = 1
-     expand_ratio: float = 4.0
-     is_multiscale: bool = False
-
-
- def _uir(
-     start_dw_kernel_size,
-     mid_dw_kernel_size,
-     filters,
-     strides=1,
-     expand_ratio=4.0,
-     is_multiscale=False,
- ):
-     return UniversalInvertedResidualConfig(
-         start_dw_kernel_size=start_dw_kernel_size,
-         mid_dw_kernel_size=mid_dw_kernel_size,
-         filters=filters,
-         strides=strides,
-         expand_ratio=expand_ratio,
-         is_multiscale=is_multiscale,
-     )
-
-
- @dataclass(frozen=True)
- class MultiQueryAttentionBlockConfig:
-     num_heads: int = 8
-     kv_dim: int = 16
-     kv_strides: int = 1
-     mmqa_avg_pool_kv: bool = False
-     mmqa_dropout: float = 0.0
-     mmqa_dw_kernel_size: int = 3
-     is_multiscale: bool = False
-
-
- def _mmqa(
-     num_heads,
-     kv_dim,
-     kv_strides,
-     mmqa_avg_pool_kv=False,
-     is_multiscale=False,
- ):
-     conf = MultiQueryAttentionBlockConfig(
-         num_heads=num_heads,
-         kv_dim=kv_dim,
-         kv_strides=kv_strides,
-         mmqa_avg_pool_kv=mmqa_avg_pool_kv,
-         is_multiscale=is_multiscale,
-     )
-     return conf
-
-
- # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L596
- def gemma3n_mobilenet_def():
-     return [
-         # Stage 1: Edge Residuals
-         [_er(3, 128, 2)] + [_er(3, 128, 1)] * 2,
-         # Stage 2: Universal Inverted Residuals
-         [_uir(3, 5, 256, 2, 6.0)] + [_uir(k, 0, 256) for k in [5, 3, 5, 3]],
-         # Stage 3: Universal Inverted Residuals with Multi-Query Attention
-         [_uir(5, 5, 640, 2, 6.0)]
-         + [_uir(5, 0, 640)] * 7
-         + [_uir(0, 0, 640, 1, 1.0)]
-         + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0)] * 13
-         + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0, is_multiscale=True)],
-         # Stage 4: Universal Inverted Residuals with Multi-Query Attention
-         [_uir(5, 5, 1280, 2, 6.0)]
-         + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0)] * 18
-         + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0, is_multiscale=True)],
-     ]
-
-
- class VisionTower(nn.Module):
-     def __init__(self, config: VisionConfig):
-         super().__init__()
-         self.conv_stem = ConvNormAct(
-             Conv2dSame,
-             in_chs=3,
-             out_chs=64,
-             kernel_size=3,
-             stride=2,
-             padding=0,
-             eps=1e-05,
-             bias=True,
-         )
-         msfa_indices = (3, 4)
-         msfa_output_resolution = (16, 16)
-
-         (num_features, self.blocks) = self.build()
-         self.num_features = self.head_hidden_size = (
-             num_features  # output of msfa is output of forward_features()
-         )
-         self.msfa_indices = msfa_indices
-         self.msfa_output_resolution = msfa_output_resolution
-
-         self.msfa = MobileNetV5MultiScaleFusionAdapter(
-             in_chs=[1920],
-             out_chs=2048,
-             output_resolution=self.msfa_output_resolution,
-         )
-
-     def build(self):
-         blocks = []
-         in_chs = self.conv_stem.out_chs
-         for stage, block_config in enumerate(gemma3n_mobilenet_def()):
-             block_group = []
-             for config in block_config:
-                 match config:
-                     case EdgeResidualConfig(
-                         kernel_size, filters, strides, expand_ratio, is_multiscale
-                     ):
-                         x = EdgeResidual(
-                             exp_kernel_size=kernel_size,
-                             in_chs=in_chs,
-                             out_chs=filters,
-                             stride=strides,
-                             expand_ratio=expand_ratio,
-                         )
-                         in_chs = filters  # in_chs of next is out_chs of prev
-                         block_group.append(x)
-                     case UniversalInvertedResidualConfig(
-                         start_dw_kernel_size,
-                         mid_dw_kernel_size,
-                         filters,
-                         strides,
-                         expand_ratio,
-                         is_multiscale,
-                     ):
-                         x = UniversalInvertedResidual(
-                             in_chs=in_chs,
-                             out_chs=filters,
-                             dw_kernel_size_start=start_dw_kernel_size,
-                             dw_kernel_size_mid=mid_dw_kernel_size,
-                             stride=strides,
-                             exp_ratio=expand_ratio,
-                         )
-                         in_chs = filters
-                         block_group.append(x)
-                     case MultiQueryAttentionBlockConfig(
-                         num_heads,
-                         kv_dim,
-                         kv_strides,
-                         mmqa_avg_pool_kv,
-                         is_multiscale,
-                     ):
-                         x = MobileAttention(
-                             in_chs=in_chs,
-                             out_chs=in_chs,
-                             stride=1,
-                             num_heads=num_heads,
-                             key_dim=kv_dim,
-                             value_dim=kv_dim,
-                             kv_stride=kv_strides,
-                             act_layer=None,
-                         )
-                         block_group.append(x)
-                     case _:
-                         continue
-             blocks.append(block_group)
-         return (in_chs, blocks)
-
-     def __call__(
-         self, x: mx.array, output_hidden_states: Optional[bool] = None
-     ) -> mx.array:
-         feat_idx = 0
-         x = x.transpose(0, 2, 3, 1)  # Convert from NCHW to NHWC
-         x = self.conv_stem(x)
-         intermediates = []
-
-         if feat_idx in self.msfa_indices:
-             intermediates.append(x)
-
-         # MBV5 is constructed of 4 stages, each stage is a group of blocks.
-         for block_group in self.blocks:
-             feat_idx += 1
-             for block in block_group:
-                 x = block(x)
-
-             if feat_idx in self.msfa_indices:
-                 intermediates.append(x)
-
-         x = self.msfa(intermediates)
-         return x
-
-
- class VisionModel(nn.Module):
-     def __init__(self, config: VisionConfig):
-         super().__init__()
-         self.model_type = config.model_type
-         if self.model_type not in ["gemma3", "gemma3_vision", "gemma3n_vision"]:
-             raise ValueError(f"Unsupported model type: {self.model_type}")
-
-         self.timm_model = VisionTower(config)
-
-     def __call__(
-         self, x: mx.array, output_hidden_states: Optional[bool] = None
-     ) -> mx.array:
-         return self.timm_model(x, output_hidden_states)
-
-     def sanitize(self, weights):
-         sanitized_weights = {}
-         skip_transpose = False
-         _, H, _, C = weights["vision_tower.timm_model.blocks.0.0.conv_exp.weight"].shape
-         if C > H:
-             skip_transpose = True
-
-         for k, v in weights.items():
-             # PyTorch conv2d weight: [out_channels, in_channels, kH, kW]
-             # MLX conv2d weight: [out_channels, kH, KW, in_channels]
-             if ("conv" in k and "weight" in k) or ("attn" and "proj.weight") in k:
-                 if len(v.shape) == 4 and not skip_transpose:
-                     v = v.transpose(0, 2, 3, 1)
-             sanitized_weights[k] = v
-
-         return sanitized_weights
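A note on the structural pattern matching used in VisionTower.build() above: @dataclass automatically generates __match_args__, which is what lets the case arms destructure the frozen config objects positionally. The self-contained sketch below illustrates the mechanism (it mirrors, but is not, the removed code); match statements require Python 3.10+, consistent with the wheel's cp310 tag.

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class EdgeResidualConfig:
        kernel_size: int = 3
        filters: int = 32

    cfg = EdgeResidualConfig(5, 64)
    match cfg:
        case EdgeResidualConfig(kernel_size, filters):
            # positional capture works because __match_args__ == ("kernel_size", "filters")
            print(kernel_size, filters)  # -> 5 64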