nexaai-1.0.19rc5-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release.


This version of nexaai might be problematic.

Files changed (221)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  5. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  6. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  7. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  8. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  9. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  10. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  11. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  12. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
  13. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
  14. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  15. nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
  16. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  17. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
  18. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
  19. nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
  20. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
  21. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
  23. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
  24. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
  25. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  26. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
  33. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
  34. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
  35. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
  36. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
  37. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
  38. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
  39. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  40. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
  41. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
  42. nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
  43. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  44. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
  45. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
  46. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
  47. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  48. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
  49. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
  50. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
  51. nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
  54. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
  55. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
  56. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
  57. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
  58. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
  59. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
  60. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
  61. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
  62. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
  64. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
  66. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
  67. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
  195. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
  196. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
  197. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
  198. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
  199. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
  200. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
  201. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  202. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  203. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  204. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  205. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
  206. nexaai/mlx_backend/vlm/interface.py +21 -4
  207. nexaai/mlx_backend/vlm/main.py +6 -2
  208. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  209. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  210. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  211. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  212. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  213. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  214. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  215. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  216. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  217. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  218. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
  219. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +221 -21
  220. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
  221. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py
@@ -0,0 +1,560 @@
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import pixel_shuffle
+
+
+@dataclass
+class VisionConfig:
+    model_type: str
+    hidden_size: int
+    image_size: int
+    initializer_range: float
+    intermediate_size: int
+    norm_eps: float
+    num_attention_heads: int
+    num_channels: int
+    num_hidden_layers: int
+    patch_size: int
+    pixel_shuffle_ratio: float
+    projector_dropout: float
+    projector_input_dim: int
+    projector_output_dim: int
+    rope_theta: float
+    vision_feature_layer: int
+    vision_feature_select_strategy: str
+    vision_output_dim: int
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Llama4MultiModalProjector(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.linear_1 = nn.Linear(
+            config.vision_config.vision_output_dim,
+            config.text_config.hidden_size,
+            bias=False,
+        )
+
+    def __call__(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        return hidden_states
+
+
+class Llama4VisionPixelShuffleMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+        self.inner_dim = int(
+            config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+        )
+        self.output_dim = config.projector_output_dim
+        self.mlp = Llama4VisionMLP(config, bias=False, is_projector=True)
+
+    def __call__(self, encoded_patches: mx.array) -> mx.array:
+        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
+        return self.mlp(encoded_patches)
+
+
+# TODO there is a different RoPE for vision encoder, defined as below
+def reshape_for_broadcast(freqs_ci: mx.array, query: mx.array):
+    ndim = query.ndim
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]
+    return freqs_ci.reshape(*shape)
+
+
+def view_as_complex(x):
+    """
+    Convert a tensor with shape (..., 2) to a complex tensor with shape (...).
+
+    Args:
+        x: A real tensor with last dimension of size 2.
+
+    Returns:
+        A complex tensor with size one less than the input.
+    """
+    # Ensure the last dimension is size 2
+    assert x.shape[-1] == 2, f"Last dimension must be 2, got {x.shape[-1]}"
+
+    # Get real and imaginary parts
+    real, imag = x[..., 0], x[..., 1]
+
+    # Create complex tensor
+    return real + 1j * imag
+
+
+def view_as_real(x):
+    """
+    Convert a complex tensor with shape (...) to a real tensor with shape (..., 2).
+
+    Args:
+        x: A complex tensor.
+
+    Returns:
+        A real tensor with an extra dimension of size 2.
+    """
+    # Get real and imaginary parts
+    real = mx.real(x)
+    imag = mx.imag(x)
+
+    # Combine into a tensor with last dimension 2
+    return mx.stack([real, imag], axis=-1)
+
+
+def vision_apply_rotary_emb(
+    query: mx.array,
+    key: mx.array,
+    freqs_ci: mx.array,
+) -> Tuple[mx.array, mx.array]:
+
+    query_ = view_as_complex(query.astype(mx.float32).reshape(*query.shape[:-1], -1, 2))
+    key_ = view_as_complex(key.astype(mx.float32).reshape(*key.shape[:-1], -1, 2))
+    freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_)
+    query_out = view_as_real(query_ * freqs_ci).flatten(3)
+    key_out = view_as_real(key_ * freqs_ci).flatten(3)
+    return query_out.astype(query.dtype), key_out.astype(key.dtype)
+
+
+class Llama4VisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.num_key_value_groups = 1
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=True
+        )
+        self.k_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=True
+        )
+        self.v_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=True
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=True
+        )
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        freqs_ci: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[mx.array] = None,
+    ):
+        B, L, D = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+        key_states = self.k_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+        value_states = self.v_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+
+        query_states, key_states = vision_apply_rotary_emb(
+            query_states, key_states, freqs_ci=freqs_ci
+        )
+
+        query_states = query_states.transpose(0, 2, 1, 3)
+        key_states = key_states.transpose(0, 2, 1, 3)
+        value_states = value_states.transpose(0, 2, 1, 3)
+
+        attn_output = mx.fast.scaled_dot_product_attention(
+            query_states, key_states, value_states, scale=self.scale
+        )
+
+        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output
+
+
+class Llama4VisionMLP(nn.Module):
+    def __init__(self, config, bias=True, is_projector=False):
+        super().__init__()
+        self.config = config
+        self.activation_fn = nn.GELU(approx="fast")  # ACT2FN[config.hidden_act]
+        self.is_projector = is_projector
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        # Determine dimensions for first linear layer based on whether this is a projector
+        fc1_input_dim = self.intermediate_size if is_projector else self.hidden_size
+        fc1_output_dim = (
+            config.projector_input_dim if is_projector else self.intermediate_size
+        )
+
+        self.fc1 = nn.Linear(fc1_input_dim, fc1_output_dim, bias=bias)
+
+        # Determine dimensions for second linear layer
+        fc2_input_dim = (
+            config.projector_output_dim if is_projector else self.intermediate_size
+        )
+        fc2_output_dim = (
+            config.projector_output_dim if is_projector else self.hidden_size
+        )
+
+        self.fc2 = nn.Linear(fc2_input_dim, fc2_output_dim, bias=bias)
+
+        self.is_projector = is_projector
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+
+        if self.is_projector:
+            return self.activation_fn(self.fc2(hidden_states))
+
+        return self.fc2(hidden_states)
+
+
+class Llama4VisionEncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Llama4VisionAttention(config)
+        self.mlp = Llama4VisionMLP(config)
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        hidden_state: mx.array,
+        freqs_ci: mx.array,
+        mask: Optional[mx.array] = None,
+    ):
+        # Self Attention
+        residual = hidden_state
+
+        hidden_state = self.input_layernorm(hidden_state)
+
+        hidden_state = self.self_attn(
+            hidden_state,
+            freqs_ci=freqs_ci,
+            mask=mask,
+        )
+        hidden_state = residual + hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        hidden_state = residual + hidden_state
+        return hidden_state
+
+
+class Llama4VisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`Llama4VisionEncoderLayer`].
+
+    Args:
+        config: VisionConfig
+    """
+
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = [
+            Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
+        ]
+        self.config = config
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        freqs_ci: mx.array,  # TODO move this to an attribute instead of keeping it around
+        mask: Optional[mx.array] = None,
+    ):
+
+        for i, encoder_layer in enumerate(self.layers):
+            hidden_states = encoder_layer(
+                hidden_state=hidden_states,
+                mask=mask,
+                freqs_ci=freqs_ci,
+            )
+
+        return hidden_states
+
+
+class Llama4UnfoldConvolution(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        kernel_size = config.patch_size
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        self.kernel_size = kernel_size
+        self.stride = config.patch_size
+        self.linear = nn.Linear(
+            config.num_channels * kernel_size[0] * kernel_size[1],
+            config.hidden_size,
+            bias=False,
+        )
+
+    def _pair(self, x):
+        """Convert input to a pair of values."""
+        if isinstance(x, (list, tuple)):
+            return tuple(x)
+        return (x, x)
+
+    def unfold(self, input_tensor):
+        """
+        Extract sliding local blocks from a batched input tensor (MLX implementation).
+
+        This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
+
+        Args:
+            input_tensor: Input tensor of shape (B, C, H, W)
+
+        Returns:
+            Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
+            where L is the number of blocks
+        """
+        # Convert to pairs
+        kernel_size = self._pair(self.kernel_size)
+        stride = self._pair(self.stride)
+        padding = (0, 0)  # No padding in the original code
+        dilation = (1, 1)  # Default dilation
+
+        # Input shape
+        batch_size, channels, height, width = input_tensor.shape
+
+        # Calculate output dimensions
+        height_out = (
+            height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+        ) // stride[0] + 1
+        width_out = (
+            width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+        ) // stride[1] + 1
+
+        # Initialize output arrays
+        blocks = []
+
+        # Extract blocks
+        for i in range(0, height - kernel_size[0] * dilation[0] + 1, stride[0]):
+            for j in range(0, width - kernel_size[1] * dilation[1] + 1, stride[1]):
+                # Extract the block for all channels
+                block = []
+                for di in range(kernel_size[0]):
+                    for dj in range(kernel_size[1]):
+                        h_idx = i + di * dilation[0]
+                        w_idx = j + dj * dilation[1]
+                        # Get the block for all channels and add to our list
+                        block.append(input_tensor[:, :, h_idx, w_idx])
+
+                # Stack the channel-blocks
+                block = mx.stack(block, axis=1)  # Shape: (B, k*k, C)
+                block = mx.transpose(block, [0, 2, 1])  # Shape: (B, C, k*k)
+                blocks.append(block)
+
+        # Stack all blocks together
+        result = mx.stack(blocks, axis=-1)  # Shape: (B, C, k*k, L)
+
+        # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
+        result = mx.reshape(
+            result,
+            (
+                batch_size,
+                channels * kernel_size[0] * kernel_size[1],
+                height_out * width_out,
+            ),
+        )
+
+        return result
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = self.unfold(hidden_states)
+        hidden_states = hidden_states.swapaxes(1, 2)
+        hidden_states = self.linear(hidden_states)
+        return hidden_states
+
+
+class Llama4VisionRotaryEmbedding:
+    def __init__(self, config):
+        super().__init__()
+        idx = config.image_size // config.patch_size
+        img_idx = mx.arange(idx**2, dtype=mx.int32).reshape(idx**2, 1)
+        img_idx = mx.concatenate([img_idx, img_idx[:1]], axis=0)
+        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
+        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
+        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
+        freq_dim = config.hidden_size // config.num_attention_heads // 2
+        rope_freq = 1.0 / (
+            config.rope_theta
+            ** (
+                mx.arange(0, freq_dim, 2, dtype=mx.float32)[: (freq_dim // 2)]
+                / freq_dim
+            )
+        )
+
+        # Expand dimensions for frequencies_x and frequencies_y
+        freqs_x_expanded = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
+        freqs_y_expanded = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
+
+        def repeat_interleave(tensor, repeats, dim=-1):
+            # Get the shape
+            shape = list(tensor.shape)
+
+            # Reshape to add an extra dimension for repeating
+            tensor = mx.reshape(tensor, shape[:-1] + [shape[-1], 1])
+
+            # Repeat along the new dimension
+            tensor = mx.repeat(tensor, repeats, axis=-1)
+
+            # Reshape to flatten the last two dimensions
+            return mx.reshape(tensor, shape[:-1] + [shape[-1] * repeats])
+
+        # Apply interleaving
+        freqs_x = repeat_interleave(freqs_x_expanded, 2)
+        freqs_y = repeat_interleave(freqs_y_expanded, 2)
+        freqs = mx.concatenate([freqs_x, freqs_y], axis=-1).astype(mx.float32)[..., ::2]
+        # Replaced masked_fill with where
+        mask = img_idx.reshape(-1, 1, 1) < 0
+        freqs = mx.where(mask, mx.zeros_like(freqs), freqs)
+        freq_cis = mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
+        freq_cis = view_as_complex(freq_cis)
+        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
+
+    def __call__(self, hidden_states):
+        return self.freqs_ci
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+        self.model_type = config.model_type
+        if self.model_type not in ["llama4", "llama4_vision_model"]:
+            raise ValueError(f"Model type {self.model_type} not supported")
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+        self.scale = config.hidden_size**-0.5
+
+        self.class_embedding = self.scale * mx.random.normal((self.hidden_size,))
+        self.positional_embedding_vlm = self.scale * mx.random.normal(
+            (self.num_patches, self.hidden_size)
+        )
+
+        self.patch_embedding = Llama4UnfoldConvolution(config)
+
+        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
+
+        # layer norms
+        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
+        self.layernorm_post = nn.LayerNorm(self.hidden_size)
+
+        # encoders
+        self.model = Llama4VisionEncoder(config)
+        self.vision_adapter = Llama4VisionPixelShuffleMLP(config)
+
+    def get_input_embeddings(self):
+        """
+        This function is used to fetch the first embedding layer to activate grads on inputs.
+        """
+        return self.patch_embedding
+
+    def __call__(
+        self,
+        pixel_values: mx.array,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        capture_activations: Optional[bool] = True,
+    ):
+
+        batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
+        num_concurrent_media = 1
+        num_chunks = 1
+
+        hidden_state = self.patch_embedding(pixel_values)
+
+        _, num_patches, hidden_dim = hidden_state.shape
+
+        # Add cls token
+        hidden_state = hidden_state.reshape(
+            batch_size_times_num_tiles * num_concurrent_media * num_chunks,
+            num_patches,
+            hidden_dim,
+        )
+
+        class_embedding = mx.broadcast_to(
+            self.class_embedding, (hidden_state.shape[0], 1, hidden_state.shape[-1])
+        )
+        hidden_state = mx.concatenate([hidden_state, class_embedding], axis=1)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(
+            batch_size_times_num_tiles * num_concurrent_media,
+            num_chunks,
+            num_patches,
+            hidden_dim,
+        )
+
+        positional_embedding = self.positional_embedding_vlm
+        hidden_state = hidden_state + positional_embedding
+
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        hidden_state = hidden_state.reshape(batch_size_times_num_tiles, -1, hidden_dim)
+        freqs_ci = self.rotary_embedding(pixel_values)
+
+        hidden_state = self.model(
+            hidden_state,
+            mask=None,
+            freqs_ci=freqs_ci,
+        )
+
+        hidden_state = self.layernorm_post(hidden_state)
+
+        hidden_state = hidden_state[:, :-1, :]
+
+        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+        final_hidden_state = self.vision_adapter(hidden_state)
+
+        # Return only the final state
+        return final_hidden_state
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py
@@ -0,0 +1,8 @@
+from .llava import (
+    LanguageModel,
+    Model,
+    ModelConfig,
+    TextConfig,
+    VisionConfig,
+    VisionModel,
+)
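
Side note on the vendored llama4/vision.py shown above: it emulates torch.view_as_complex / torch.view_as_real with plain MLX ops so that the vision rotary embedding can be applied as a complex multiplication. The following is a minimal, standalone sketch of that round trip; it is not part of the wheel, and the function names are reproduced here only for illustration, assuming nothing beyond mlx.core:

import mlx.core as mx

def view_as_complex(x: mx.array) -> mx.array:
    # Pair up the last axis (size 2) into a single complex value, as the vendored helper does.
    assert x.shape[-1] == 2
    return x[..., 0] + 1j * x[..., 1]

def view_as_real(z: mx.array) -> mx.array:
    # Split a complex array back into (..., 2) real/imaginary pairs.
    return mx.stack([mx.real(z), mx.imag(z)], axis=-1)

pairs = mx.array([[1.0, 2.0], [3.0, -1.0]])  # two (real, imag) pairs, shape (2, 2)
z = view_as_complex(pairs)                   # complex64, shape (2,)
restored = view_as_real(z)                   # back to shape (2, 2), matching `pairs`
print(z, restored)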