nexaai 1.0.19rc5__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.

Files changed (221)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  5. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  6. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  7. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  8. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  9. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  10. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  11. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  12. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
  13. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
  14. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  15. nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
  16. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  17. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
  18. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
  19. nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
  20. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
  21. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
  23. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
  24. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
  25. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  26. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
  33. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
  34. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
  35. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
  36. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
  37. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
  38. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
  39. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  40. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
  41. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
  42. nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
  43. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  44. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
  45. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
  46. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
  47. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  48. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
  49. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
  50. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
  51. nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
  54. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
  55. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
  56. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
  57. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
  58. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
  59. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
  60. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
  61. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
  62. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
  64. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
  66. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
  67. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
  195. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
  196. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
  197. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
  198. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
  199. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
  200. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
  201. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  202. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  203. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  204. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  205. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
  206. nexaai/mlx_backend/vlm/interface.py +21 -4
  207. nexaai/mlx_backend/vlm/main.py +6 -2
  208. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  209. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  210. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  211. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  212. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  213. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  214. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  215. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  216. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  217. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  218. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
  219. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +221 -21
  220. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
  221. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py
@@ -0,0 +1,1022 @@
+ import inspect
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from math import sqrt
+ from typing import Dict, List, Optional, Tuple, Type
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+ from ..base import check_array_shape
+ from ..kernels import bicubic_interpolate, nearest_interpolate
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L24
+ class MobileNetV5MultiScaleFusionAdapter(nn.Module):
+     """Multi-layer fusion token adapter.
+     Attributes:
+         out_filters: The number of output filters.
+         output_resolution: The output resolution.
+         activation: The activation function.
+         expansion_ratio: The expansion ratio.
+         upsampling_interpolation: The upsampling interpolation.
+         use_layer_scale: Whether to use layer scale.
+         layer_scale_init_value: The initial value of the layer scale.
+         skip_projection: Whether to skip the projection.
+         name: The name of the module.
+         upsize: The upsampling fn.
+         downsize: The downsampling fn.
+     """
+
+     def __init__(
+         self,
+         in_chs: List[int],
+         out_chs: int,
+         output_resolution: int,
+         expansion_ratio: float = 2.0,
+         interpolation_mode: str = "nearest",
+         use_layer_scale: bool = False,
+         layer_scale_init_value: float = 1e-5,
+         noskip: bool = True,
+     ):
+         super().__init__()
+         self.in_channels = sum(in_chs) if isinstance(in_chs, Sequence) else in_chs
+         self.out_channels = out_chs
+         self.output_resolution = to_2tuple(output_resolution)
+         self.expansion_ratio = expansion_ratio
+         self.interpolation_mode = interpolation_mode
+         self.use_layer_scale = use_layer_scale
+         self.layer_scale_init_value = layer_scale_init_value
+         self.noskip = noskip
+
+         norm_layer = RMSNormAct2d
+         self.ffn = UniversalInvertedResidual(
+             in_chs=self.in_channels,
+             out_chs=self.out_channels,
+             dw_kernel_size_mid=0,
+             exp_ratio=self.expansion_ratio,
+             norm_layer=norm_layer,
+             noskip=self.noskip,
+             layer_scale_init_value=(
+                 self.layer_scale_init_value if self.use_layer_scale else None
+             ),
+         )
+
+         self.norm = norm_layer(self.out_channels, eps=1e-6, apply_act=False)
+
+     def __call__(self, inputs: list[mx.array]) -> mx.array:
+         inputs = [i.transpose(0, 3, 1, 2) for i in inputs]
+         high_resolution = inputs[0].shape[
+             -2:
+         ]  # Assuming the first input is the highest resolution.
+         resized_inputs = []
+
+         for _, img in enumerate(inputs):
+             if any([r < hr for r, hr in zip(img.shape[-2:], high_resolution)]):
+                 img = nearest_interpolate(img, size=high_resolution)
+
+             resized_inputs.append(img)
+
+         channel_cat_imgs = mx.concatenate(
+             resized_inputs, axis=1
+         )  # Cat on channel dim, must equal self.in_channels
+         img = self.ffn(channel_cat_imgs.swapaxes(1, 3)).swapaxes(1, 3)
+
+         if any([ro != rh for ro, rh in zip(high_resolution, self.output_resolution)]):
+             if (
+                 high_resolution[0] % self.output_resolution[0] != 0
+                 or high_resolution[1] % self.output_resolution[1] != 0
+             ):
+                 img = bicubic_interpolate(img, self.output_resolution)
+             else:
+                 h_strides = high_resolution[0] // self.output_resolution[0]
+                 w_strides = high_resolution[1] // self.output_resolution[1]
+
+                 img = nn.AvgPool2d(
+                     kernel_size=(h_strides, w_strides),
+                     stride=(h_strides, w_strides),
+                 )(img.swapaxes(1, 3))
+
+         img = self.norm(img) if self.noskip else img
+
+         return img
+
+
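For orientation, a hedged shape sketch of this adapter as the VisionTower further down configures it (illustrative, not part of the diff; the input shapes are inferred from the tower's msfa defaults, and the package's own `nearest_interpolate` kernel is assumed importable). The tower itself passes `in_chs=[1920]`, the summed width:

    # For a 768x768 image the tower collects the stage-3 and stage-4 maps (NHWC):
    feats = [
        mx.zeros((1, 48, 48, 640)),   # stage 3: highest resolution, listed first
        mx.zeros((1, 24, 24, 1280)),  # stage 4: nearest-upsampled to 48x48 inside
    ]
    msfa = MobileNetV5MultiScaleFusionAdapter(
        in_chs=[640, 1280],           # concatenated to 1920 channels
        out_chs=2048,
        output_resolution=16,
    )
    out = msfa(feats)                 # (1, 16, 16, 2048); 48 % 16 == 0, so avg-pool path
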
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/layer_scale.py#L22
+ class LayerScale2d(nn.Module):
+     def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = init_values * mx.ones((dim,))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return x * self.gamma  # MLX has no in-place mul_(); inplace kept for timm API parity
+
+
+ def rms_norm2d(
+     x: mx.array,
+     normalized_shape: List[int],
+     weight: Optional[mx.array] = None,
+     eps: float = 1e-5,
+ ):
+     assert len(normalized_shape) == 1
+     dtype = x.dtype
+     v = mx.power(x, 2)
+     v = mx.mean(v, axis=1, keepdims=True)
+     x = x * mx.rsqrt(v + eps)
+     if weight is not None:
+         x = x.astype(dtype) * weight.reshape(1, -1, 1, 1)
+     return x
+
+
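In effect rms_norm2d is RMSNorm over the channel axis of an NCHW tensor: x * rsqrt(mean_c(x^2) + eps), scaled by a per-channel weight. A quick numeric sanity check (illustrative only):

    x = mx.random.normal((2, 8, 4, 4))            # NCHW
    y = rms_norm2d(x, [8], weight=mx.ones((8,)))
    print(mx.mean(y ** 2, axis=1))                # ~1.0 at every spatial position
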
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/norm_act.py#L504
+ class RMSNormAct2d(nn.RMSNorm):
+     def __init__(
+         self,
+         num_channels,
+         eps=1e-6,
+         apply_act: bool = True,
+     ):
+         super().__init__(dims=num_channels, eps=eps)
+         self.normalized_shape = [num_channels]
+         self.drop = nn.Identity()
+         self.act = nn.GELU() if apply_act else nn.Identity()
+
+     def __call__(self, x: mx.array) -> mx.array:
+
+         x = x.transpose(0, 3, 1, 2)  # Convert from NHWC to NCHW
+         x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+         x = self.drop(x)
+         x = self.act(x)
+         x = x.transpose(0, 2, 3, 1)  # Convert back to NHWC
+         return x
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L310
+ class UniversalInvertedResidual(nn.Module):
+     def __init__(
+         self,
+         in_chs: int,
+         out_chs: int,
+         dw_kernel_size_start: int = 0,
+         dw_kernel_size_mid: int = 3,
+         dw_kernel_size_end: int = 0,
+         stride: int = 1,
+         dilation: int = 1,
+         group_size: int = 1,
+         pad_type: str = "",
+         noskip: bool = False,
+         exp_ratio: float = 1.0,
+         norm_layer=RMSNormAct2d,
+         conv_kwargs: Optional[Dict] = None,
+         drop_path_rate: float = 0.0,
+         layer_scale_init_value: Optional[float] = 1e-5,
+     ):
+         super().__init__()
+         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
+         if stride > 1:
+             assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end
+
+         if dw_kernel_size_start:
+             dw_start_stride = stride if not dw_kernel_size_mid else 1
+             dw_start_groups = num_groups(group_size, in_chs)
+             self.dw_start = ConvNormAct(
+                 nn.Conv2d,
+                 in_chs,
+                 in_chs,
+                 kernel_size=dw_kernel_size_start,
+                 stride=dw_start_stride,
+                 padding=(dw_kernel_size_start - 1) // 2,
+                 dilation=dilation,
+                 groups=dw_start_groups,
+                 bias=False,
+                 apply_act=False,
+                 eps=1e-05,
+             )
+         else:
+             self.dw_start = nn.Identity()
+
+         mid_chs = make_divisible(in_chs * exp_ratio)
+         self.pw_exp = ConvNormAct(
+             nn.Conv2d,
+             in_chs,
+             mid_chs,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             groups=1,
+             bias=False,
+             eps=1e-05,
+         )
+
+         if dw_kernel_size_mid:
+             dw_mid_groups = num_groups(group_size, mid_chs)
+             self.dw_mid = ConvNormAct(
+                 Conv2dSame,
+                 mid_chs,
+                 mid_chs,
+                 kernel_size=dw_kernel_size_mid,
+                 stride=stride,
+                 padding=0,
+                 dilation=dilation,
+                 groups=dw_mid_groups,
+                 bias=False,
+                 eps=1e-05,
+             )
+         else:
+             self.dw_mid = nn.Identity()
+
+         self.pw_proj = ConvNormAct(
+             nn.Conv2d,
+             mid_chs,
+             out_chs,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             groups=1,
+             bias=False,
+             apply_act=False,
+             eps=1e-05,
+         )
+         if layer_scale_init_value is not None:
+             self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+         else:
+             self.layer_scale = nn.Identity()
+
+     def __call__(self, x: mx.array) -> mx.array:
+         shortcut = x
+         x = self.dw_start(x)
+         x = self.pw_exp(x)
+         x = self.dw_mid(x)
+         x = self.pw_proj(x)
+         x = self.layer_scale(x)
+         if self.has_skip:
+             x = x + shortcut
+         return x
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/conv_bn_act.py#L15
+ class ConvNormAct(nn.Module):
+     def __init__(
+         self,
+         conv_cls,
+         in_chs: int,
+         out_chs: int,
+         kernel_size: int = 3,
+         stride: int = 1,
+         padding: int = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = False,
+         apply_act: bool = True,
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+         self.out_chs = out_chs
+         self.conv = conv_cls(
+             in_chs,
+             out_chs,
+             kernel_size,
+             stride,
+             padding,
+             (dilation, dilation),
+             groups,
+             bias,
+         )
+         self.bn = RMSNormAct2d(out_chs, eps=eps, apply_act=apply_act)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         c = self.conv(x)
+         r = self.bn(c)
+         return r
+
+
+ def pad_same(
+     x,
+     kernel_size: List[int],
+     stride: List[int],
+     dilation: List[int] = (1, 1),
+     value: float = 0,
+ ):
+     """
+     Input should be in MLX format
+     """
+     ih, iw = x.shape[1:3]
+     pad_h = get_same_padding(ih, kernel_size[0], stride[0], dilation[0])
+     pad_w = get_same_padding(iw, kernel_size[1], stride[1], dilation[1])
+
+     # MLX pad format: [(low, high), (low, high), ...] for each axis
+     # Padding order is reversed compared to PyTorch F.pad
+     pad_widths = [
+         (0, 0),  # No padding for batch dimension
+         (pad_h // 2, pad_h - pad_h // 2),  # Height padding
+         (pad_w // 2, pad_w - pad_w // 2),  # Width padding
+         (0, 0),  # No padding for channel dimension
+     ]
+
+     x = mx.pad(x, pad_widths, constant_values=value)
+     return x
+
+
+ def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+     dynamic = False
+     if isinstance(padding, str):
+         # for any string padding, the padding will be calculated for you, one of three ways
+         padding = padding.lower()
+         if padding == "same":
+             # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+             if is_static_pad(kernel_size, **kwargs):
+                 # static case, no extra overhead
+                 padding = get_padding(kernel_size, **kwargs)
+             else:
+                 # dynamic 'SAME' padding, has runtime/GPU memory overhead
+                 padding = 0
+                 dynamic = True
+         elif padding == "valid":
+             # 'VALID' padding, same as padding=0
+             padding = 0
+         else:
+             # Default to PyTorch style 'same'-ish symmetric padding
+             padding = get_padding(kernel_size, **kwargs)
+     return padding, dynamic
+
+
+ def get_same_padding(
+     input_size: int, kernel_size: int, stride: int, dilation: int = 1
+ ) -> int:
+     """Calculate padding needed for 'same' output size."""
+     effective_kernel_size = dilation * (kernel_size - 1) + 1
+     output_size = (input_size + stride - 1) // stride
+     total_padding = max(
+         0, (output_size - 1) * stride + effective_kernel_size - input_size
+     )
+     return total_padding
+
+
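The 'same' padding arithmetic matches the TF/timm convention: the output size is ceil(input / stride), and whatever the kernel cannot cover becomes total padding. A worked example (illustrative):

    # 224px input, 3x3 kernel, stride 2, dilation 1:
    #   output_size   = (224 + 2 - 1) // 2 = 112
    #   total_padding = max(0, (112 - 1) * 2 + 3 - 224) = 1
    pad = get_same_padding(224, 3, 2)  # -> 1
    # pad_same() then splits it low/high as (1 // 2, 1 - 1 // 2) = (0, 1),
    # i.e. the extra pixel lands on the bottom/right, as in TF 'SAME'.
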
+ def get_padding(kernel_size, stride=1, dilation=1, **_):
+     """Get symmetric padding for given kernel size."""
+     if isinstance(kernel_size, int):
+         kernel_size = [kernel_size, kernel_size]
+     if isinstance(stride, int):
+         stride = [stride, stride]
+     if isinstance(dilation, int):
+         dilation = [dilation, dilation]
+
+     padding = []
+     for k, d in zip(kernel_size, dilation):
+         effective_k = d * (k - 1) + 1
+         pad_total = effective_k - 1
+         padding.append(pad_total // 2)
+     return tuple(padding)
+
+
+ def is_static_pad(kernel_size, stride=1, dilation=1, **_):
+     """Check if padding can be calculated statically."""
+     if isinstance(kernel_size, int):
+         kernel_size = [kernel_size, kernel_size]
+     if isinstance(stride, int):
+         stride = [stride, stride]
+     if isinstance(dilation, int):
+         dilation = [dilation, dilation]
+
+     # Static padding is possible when stride is 1 for all dimensions
+     return all(s == 1 for s in stride)
+
+
+ class Conv2dSame(nn.Conv2d):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         self.kernel_size = self.weight.shape[1:3]
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+         y = mx.conv2d(
+             x, self.weight, self.stride, self.padding, self.dilation, self.groups
+         )
+         if "bias" in self:
+             y = y + self.bias
+         return y
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L629
+ class EdgeResidual(nn.Module):
+     def __init__(
+         self,
+         in_chs: int,
+         out_chs: int,
+         exp_kernel_size: int = 3,
+         stride: int = 1,
+         dilation: int = 1,
+         group_size: int = 0,
+         pad_type: str = "",
+         force_in_chs: int = 0,
+         noskip: bool = False,
+         expand_ratio: float = 1.0,
+         pw_kernel_size: int = 1,
+         norm_layer=RMSNormAct2d,
+     ):
+         super().__init__()
+
+         if force_in_chs > 0:
+             mid_chs = make_divisible(force_in_chs * expand_ratio)
+         else:
+             mid_chs = make_divisible(in_chs * expand_ratio)
+
+         groups = num_groups(group_size, mid_chs)
+
+         self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
+
+         self.conv_exp = Conv2dSame(
+             in_chs,
+             mid_chs,
+             kernel_size=exp_kernel_size,
+             stride=stride,
+             padding=0,
+             dilation=(dilation, dilation),
+             groups=groups,
+             bias=False,
+         )
+
+         self.bn1 = norm_layer(mid_chs, eps=1e-05) if norm_layer else nn.Identity()
+
+         # Point-wise linear projection
+         padding_pwl = (pw_kernel_size - 1) // 2
+         self.conv_pwl = nn.Conv2d(
+             mid_chs,
+             out_chs,
+             kernel_size=pw_kernel_size,
+             padding=padding_pwl,
+             bias=False,
+         )
+
+         self.bn2 = (
+             norm_layer(out_chs, eps=1e-05, apply_act=False)
+             if norm_layer
+             else nn.Identity()
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         shortcut = x
+         x = self.conv_exp(x)
+         x = self.bn1(x)
+         x = self.conv_pwl(x)
+         x = self.bn2(x)
+         if self.has_skip:
+             x = x + shortcut
+         return x
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/_efficientnet_blocks.py#L449
+ class MobileAttention(nn.Module):
+     def __init__(
+         self,
+         in_chs: int,
+         out_chs: int,
+         stride: int = 1,
+         dw_kernel_size: int = 3,
+         dilation: int = 1,
+         group_size: int = 1,
+         pad_type: str = "",
+         num_heads: int = 8,
+         key_dim: int = 64,
+         value_dim: int = 64,
+         use_multi_query: bool = True,
+         query_strides: Tuple[int, int] = (1, 1),
+         kv_stride: int = 1,
+         cpe_dw_kernel_size: int = 3,
+         noskip: bool = False,
+         act_layer=nn.GELU,
+         aa_layer=None,
+         drop_path_rate: float = 0.0,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         layer_scale_init_value: Optional[float] = 1e-5,
+         use_bias: bool = False,
+     ):
+         super().__init__()
+         self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
+         self.query_strides = to_2tuple(query_strides)
+         self.kv_stride = kv_stride
+         self.has_query_stride = any([s > 1 for s in self.query_strides])
+
+         # Normalization layer
+         self.norm = RMSNormAct2d(
+             in_chs,
+             eps=1e-05,
+             apply_act=False,
+         )
+         # Determine number of heads if not provided
+         if num_heads is None:
+             assert in_chs % key_dim == 0
+             num_heads = in_chs // key_dim
+
+         # Attention layer
+         if use_multi_query:
+             self.attn = MultiQueryAttention2d(
+                 in_chs,
+                 dim_out=out_chs,
+                 num_heads=num_heads,
+                 key_dim=key_dim,
+                 value_dim=value_dim,
+                 query_strides=query_strides,
+                 kv_stride=kv_stride,
+                 dilation=dilation,
+                 padding=pad_type,
+                 dw_kernel_size=dw_kernel_size,
+                 attn_drop=attn_drop,
+                 proj_drop=proj_drop,
+             )
+         else:
+             raise NotImplementedError("attention not implemented")
+
+         # Layer scaling
+         if layer_scale_init_value is not None:
+             self.layer_scale = LayerScale2d(out_chs, layer_scale_init_value)
+         else:
+             self.layer_scale = nn.Identity()
+
+         # Drop path for residual connection
+         self.drop_path = nn.Identity()
+
+     def __call__(self, x: mx.array) -> mx.array:
+         shortcut = x
+         x = self.norm(x)
+         x = self.attn(x)
+         x = self.layer_scale(x)
+
+         # Apply skip connection if available
+         if self.has_skip:
+             x = self.drop_path(x) + shortcut
+         return x
+
+
+ def create_conv2d(
+     in_channels,
+     out_channels,
+     kernel_size,
+     stride=1,
+     dilation=1,
+     depthwise=False,
+     bias=False,
+     **kwargs,
+ ):
+     """Helper function to create a 2D convolution with common parameters"""
+     if depthwise:
+         # Depthwise convolution
+         return nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=(kernel_size - 1) // 2 * dilation,
+             dilation=dilation,
+             groups=in_channels,
+             bias=bias,
+         )
+     else:
+         # Regular convolution
+         return nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=(kernel_size - 1) // 2 * dilation,
+             dilation=dilation,
+             bias=bias,
+         )
+
+
+ def to_2tuple(x):
+     """Convert input to 2-tuple"""
+     if isinstance(x, tuple):
+         return x
+     return (x, x)
+
+
+ class NamedSequential(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self._order = []
+
+     def add_module(self, name, module):
+         setattr(self, name, module)
+         self._order.append(name)
+
+     def __call__(self, x):
+         for name in self._order:
+             x = getattr(self, name)(x)
+         return x
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/layers/attention2d.py#L82
+ class MultiQueryAttention2d(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         dim_out: Optional[int] = None,
+         num_heads: int = 8,
+         key_dim: int = 64,
+         value_dim: int = 64,
+         query_strides: Tuple[int, int] = (1, 1),
+         kv_stride: int = 1,
+         dilation: int = 1,
+         padding: str = "",
+         dw_kernel_size: int = 3,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ):
+         super().__init__()
+         dim_out = dim_out or dim
+         self.num_heads = num_heads
+         self.query_strides = to_2tuple(query_strides)
+         self.kv_stride = kv_stride
+         self.fused_attn = True
+         self.key_dim = key_dim
+         self.value_dim = value_dim
+         head_dim = key_dim
+         self.scale = head_dim**-0.5
+
+         self.query = NamedSequential()
+         self.query.add_module(
+             "proj",
+             create_conv2d(
+                 dim,
+                 self.num_heads * self.key_dim,
+                 kernel_size=1,
+             ),
+         )
+         self.key = NamedSequential()
+         if kv_stride > 1:
+             self.key.add_module(
+                 "down_conv",
+                 create_conv2d(
+                     dim,
+                     dim,
+                     kernel_size=dw_kernel_size,
+                     stride=kv_stride,
+                     dilation=dilation,
+                     padding=padding,
+                     depthwise=True,
+                 ),
+             )
+         self.key.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
+         self.key.add_module(
+             "proj", create_conv2d(dim, key_dim, kernel_size=1, bias=False)
+         )
+
+         self.value = NamedSequential()
+         if kv_stride > 1:
+             self.value.add_module(
+                 "down_conv",
+                 create_conv2d(
+                     dim,
+                     dim,
+                     kernel_size=dw_kernel_size,
+                     stride=kv_stride,
+                     dilation=dilation,
+                     padding=padding,
+                     depthwise=True,
+                 ),
+             )
+         self.value.add_module("norm", RMSNormAct2d(dim, eps=1e-6, apply_act=False))
+         self.value.add_module(
+             "proj", create_conv2d(dim, value_dim, kernel_size=1, bias=False)
+         )
+
+         # Attention dropout
+         self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
+
+         # Output projection
+         self.output = NamedSequential()
+         self.output.add_module(
+             "proj",
+             create_conv2d(
+                 value_dim * num_heads,
+                 dim_out,
+                 kernel_size=1,
+                 stride=1,
+                 bias=False,
+             ),
+         )
+         self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
+
+     def _reshape_input(self, t: mx.array):
+         """
+         Input shape MLX: [B, H, W, C]
+         Input shape PyTorch: [B, C, H, W]
+
+         PyTorch Reshape: [B, C, H, W] -> [B, C, -1] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
+         MLX Reshape: [B, H, W, C] -> [B, -1, C] -> [B, 1, -1, C] -> SDPA
+         """
+         s = t.shape
+         t = t.reshape(s[0], -1, s[3])[:, None, :, :]
+
+         return t
+
+     def _reshape_projected_query(self, t: mx.array, num_heads: int, key_dim: int):
+         """
+         Input shape MLX: [B, H, W, C] where C = num_heads * key_dim
+         """
+         B, H, W, C = t.shape
+         # t = t.reshape(B, H, W, num_heads, key_dim)
+         t = t.reshape(B, H * W, num_heads, key_dim)
+         return t.transpose(0, 2, 1, 3)
+
+     def _reshape_output(self, t: mx.array, num_heads: int, h_px: int, w_px: int):
+         """
+         Input shape: [B, NH, L, D] where L = h_px * w_px
+         Output shape MLX: [B, H, W, C] where C = NH * D
+         """
+         B, NH, L, D = t.shape
+         # First transpose to [B, L, NH, D]
+         t = t.transpose(0, 2, 1, 3)
+         # Then reshape to [B, H, W, NH*D]
+         t = t.reshape(B, h_px, w_px, NH * D)
+         return t
+
+     def __call__(self, x: mx.array, attn_mask: Optional[mx.array] = None) -> mx.array:
+         B, H, W, C = x.shape
+         q = self.query(x)
+         q = self._reshape_projected_query(q, self.num_heads, self.key_dim)
+
+         k = self.key(x)
+         k = self._reshape_input(k)
+
+         v = self.value(x)
+         v = self._reshape_input(v)
+
+         if self.fused_attn:
+             o = mx.fast.scaled_dot_product_attention(
+                 q,
+                 k,
+                 v,
+                 scale=1.0 / sqrt(q.shape[-1]),
+             )
+         else:
+             raise NotImplementedError("unfused attention not implemented")
+
+         o = self._reshape_output(
+             o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1]
+         )
+         x = self.output(o)
+         return x
+
+
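The multi-query trick is easiest to see in the shapes: queries keep num_heads heads, while keys and values are projected once and broadcast across all heads inside SDPA. A hedged shape sketch (comments only; the hypothetical input shape is chosen for round numbers, with the stage-3 settings used below):

    # num_heads=12, key_dim=value_dim=64, kv_stride=2, on a (1, 24, 24, 640) NHWC input:
    #   q: (1, 24, 24, 768) -> _reshape_projected_query -> (1, 12, 576, 64)
    #   k: downsampled to (1, 12, 12, 64) -> _reshape_input -> (1, 1, 144, 64)
    #   v: likewise (1, 1, 144, 64)
    # mx.fast.scaled_dot_product_attention broadcasts the single k/v head across
    # all 12 query heads (multi-query attention) and returns (1, 12, 576, 64),
    # which _reshape_output folds back to (1, 24, 24, 768) before the 1x1
    # output projection to dim_out.
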
+ def num_groups(group_size: Optional[int], channels: int) -> int:
+     if not group_size:  # 0 or None
+         return 1  # normal conv with 1 group
+     else:
+         # NOTE group_size == 1 -> depthwise conv
+         assert channels % group_size == 0
+         return channels // group_size
+
+
+ def make_divisible(v, divisor: int = 8, min_value=None, round_limit: float = 0.9):
+     min_value = min_value or divisor
+     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+     # Make sure that round down does not go down by more than 10%.
+     if new_v < round_limit * v:
+         new_v += divisor
+     return new_v
+
+
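make_divisible() rounds a channel count to the nearest multiple of divisor, then bumps the result up one step if rounding lost more than about 10% of the value. For illustration:

    make_divisible(3840)  # 3840: already a multiple of 8
    make_divisible(100)   # 104:  int(100 + 8 / 2) // 8 * 8
    make_divisible(11)    # 16:   rounds to 8, but 8 < 0.9 * 11, so one divisor is added back
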
+ @dataclass(frozen=True)
+ class EdgeResidualConfig:
+     kernel_size: int = 3
+     filters: int = 32
+     strides: int = 1
+     expand_ratio: float = 4.0
+     is_multiscale: bool = False
+
+
+ def _er(kernel_size, filters, strides=1, expand_ratio=4.0, is_multiscale=False):
+     return EdgeResidualConfig(
+         kernel_size=kernel_size,
+         filters=filters,
+         strides=strides,
+         expand_ratio=expand_ratio,
+         is_multiscale=is_multiscale,
+     )
+
+
+ @dataclass(frozen=True)
+ class UniversalInvertedResidualConfig:
+     start_dw_kernel_size: int = 0  # Zero size means no conv
+     mid_dw_kernel_size: int = 0  # Zero size means no conv
+     filters: int = 32
+     strides: int = 1
+     expand_ratio: float = 4.0
+     is_multiscale: bool = False
+
+
+ def _uir(
+     start_dw_kernel_size,
+     mid_dw_kernel_size,
+     filters,
+     strides=1,
+     expand_ratio=4.0,
+     is_multiscale=False,
+ ):
+     return UniversalInvertedResidualConfig(
+         start_dw_kernel_size=start_dw_kernel_size,
+         mid_dw_kernel_size=mid_dw_kernel_size,
+         filters=filters,
+         strides=strides,
+         expand_ratio=expand_ratio,
+         is_multiscale=is_multiscale,
+     )
+
+
+ @dataclass(frozen=True)
+ class MultiQueryAttentionBlockConfig:
+     num_heads: int = 8
+     kv_dim: int = 16
+     kv_strides: int = 1
+     mmqa_avg_pool_kv: bool = False
+     mmqa_dropout: float = 0.0
+     mmqa_dw_kernel_size: int = 3
+     is_multiscale: bool = False
+
+
+ def _mmqa(
+     num_heads,
+     kv_dim,
+     kv_strides,
+     mmqa_avg_pool_kv=False,
+     is_multiscale=False,
+ ):
+     conf = MultiQueryAttentionBlockConfig(
+         num_heads=num_heads,
+         kv_dim=kv_dim,
+         kv_strides=kv_strides,
+         mmqa_avg_pool_kv=mmqa_avg_pool_kv,
+         is_multiscale=is_multiscale,
+     )
+     return conf
+
+
+ # https://github.com/huggingface/new-model-addition-timm-gemma3p5-non-fork/blob/mobilenet-gemma3n-rw/timm/models/mobilenetv5.py#L596
+ def gemma3n_mobilenet_def():
+     return [
+         # Stage 1: Edge Residuals
+         [_er(3, 128, 2)] + [_er(3, 128, 1)] * 2,
+         # Stage 2: Universal Inverted Residuals
+         [_uir(3, 5, 256, 2, 6.0)] + [_uir(k, 0, 256) for k in [5, 3, 5, 3]],
+         # Stage 3: Universal Inverted Residuals with Multi-Query Attention
+         [_uir(5, 5, 640, 2, 6.0)]
+         + [_uir(5, 0, 640)] * 7
+         + [_uir(0, 0, 640, 1, 1.0)]
+         + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0)] * 13
+         + [_mmqa(12, 64, 2), _uir(0, 0, 640, 1, 2.0, is_multiscale=True)],
+         # Stage 4: Universal Inverted Residuals with Multi-Query Attention
+         [_uir(5, 5, 1280, 2, 6.0)]
+         + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0)] * 18
+         + [_mmqa(16, 96, 1), _uir(0, 0, 1280, 1, 2.0, is_multiscale=True)],
+     ]
+
+
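Expanded, this definition gives the MobileNet-V5 layout used here: four stages at widths 128/256/640/1280, each opened by a stride-2 block. A quick count (illustrative):

    stages = gemma3n_mobilenet_def()
    print([len(s) for s in stages])      # [3, 5, 37, 39]
    print(sum(len(s) for s in stages))   # 84 blocks in total
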
+ class VisionTower(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.conv_stem = ConvNormAct(
+             Conv2dSame,
+             in_chs=3,
+             out_chs=64,
+             kernel_size=3,
+             stride=2,
+             padding=0,
+             eps=1e-05,
+             bias=True,
+         )
+         msfa_indices = (3, 4)
+         msfa_output_resolution = (16, 16)
+
+         (num_features, self.blocks) = self.build()
+         self.num_features = self.head_hidden_size = (
+             num_features  # output of msfa is output of forward_features()
+         )
+         self.msfa_indices = msfa_indices
+         self.msfa_output_resolution = msfa_output_resolution
+
+         self.msfa = MobileNetV5MultiScaleFusionAdapter(
+             in_chs=[1920],
+             out_chs=2048,
+             output_resolution=self.msfa_output_resolution,
+         )
+
+     def build(self):
+         blocks = []
+         in_chs = self.conv_stem.out_chs
+         for stage, block_config in enumerate(gemma3n_mobilenet_def()):
+             block_group = []
+             for config in block_config:
+                 match config:
+                     case EdgeResidualConfig(
+                         kernel_size, filters, strides, expand_ratio, is_multiscale
+                     ):
+                         x = EdgeResidual(
+                             exp_kernel_size=kernel_size,
+                             in_chs=in_chs,
+                             out_chs=filters,
+                             stride=strides,
+                             expand_ratio=expand_ratio,
+                         )
+                         in_chs = filters  # in_chs of next is out_chs of prev
+                         block_group.append(x)
+                     case UniversalInvertedResidualConfig(
+                         start_dw_kernel_size,
+                         mid_dw_kernel_size,
+                         filters,
+                         strides,
+                         expand_ratio,
+                         is_multiscale,
+                     ):
+                         x = UniversalInvertedResidual(
+                             in_chs=in_chs,
+                             out_chs=filters,
+                             dw_kernel_size_start=start_dw_kernel_size,
+                             dw_kernel_size_mid=mid_dw_kernel_size,
+                             stride=strides,
+                             exp_ratio=expand_ratio,
+                         )
+                         in_chs = filters
+                         block_group.append(x)
+                     case MultiQueryAttentionBlockConfig(
+                         num_heads,
+                         kv_dim,
+                         kv_strides,
+                         mmqa_avg_pool_kv,
+                         is_multiscale,
+                     ):
+                         x = MobileAttention(
+                             in_chs=in_chs,
+                             out_chs=in_chs,
+                             stride=1,
+                             num_heads=num_heads,
+                             key_dim=kv_dim,
+                             value_dim=kv_dim,
+                             kv_stride=kv_strides,
+                             act_layer=None,
+                         )
+                         block_group.append(x)
+                     case _:
+                         continue
+             blocks.append(block_group)
+         return (in_chs, blocks)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         feat_idx = 0
+         x = x.transpose(0, 2, 3, 1)  # Convert from NCHW to NHWC
+         x = self.conv_stem(x)
+         intermediates = []
+
+         if feat_idx in self.msfa_indices:
+             intermediates.append(x)
+
+         # MBV5 is constructed of 4 stages, each stage is a group of blocks.
+         for block_group in self.blocks:
+             feat_idx += 1
+             for block in block_group:
+                 x = block(x)
+
+             if feat_idx in self.msfa_indices:
+                 intermediates.append(x)
+
+         x = self.msfa(intermediates)
+         return x
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         if self.model_type not in ["gemma3", "gemma3_vision", "gemma3n_vision"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.timm_model = VisionTower(config)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.timm_model(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         skip_transpose = False
+         _, H, _, C = weights["vision_tower.timm_model.blocks.0.0.conv_exp.weight"].shape
+         if C > H:
+             skip_transpose = True
+
+         for k, v in weights.items():
+             # PyTorch conv2d weight: [out_channels, in_channels, kH, kW]
+             # MLX conv2d weight: [out_channels, kH, kW, in_channels]
+             if ("conv" in k and "weight" in k) or ("attn" in k and "proj.weight" in k):
+                 if len(v.shape) == 4 and not skip_transpose:
+                     v = v.transpose(0, 2, 3, 1)
+             sanitized_weights[k] = v
+
+         return sanitized_weights
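
Taken together, a hedged end-to-end sketch of how the module appears to be used (illustrative only; `VisionConfig` lives in the sibling config.py, which is not shown in this diff, so its constructor fields here are hypothetical):

    config = VisionConfig(model_type="gemma3n_vision")  # hypothetical fields
    model = VisionModel(config)
    pixels = mx.zeros((1, 3, 768, 768))                 # NCHW input image
    tokens = model(pixels)                              # (1, 16, 16, 2048) NHWC
    # When importing PyTorch checkpoints, sanitize() flips conv weights from
    # [O, I, kH, kW] to MLX's [O, kH, kW, I] unless they are already converted.
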