nexaai 1.0.19rc5__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (221) hide show
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  5. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  6. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  7. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  8. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  9. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  10. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  11. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  12. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
  13. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
  14. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  15. nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
  16. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  17. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
  18. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
  19. nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
  20. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
  21. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
  23. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
  24. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
  25. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  26. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
  33. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
  34. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
  35. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
  36. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
  37. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
  38. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
  39. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  40. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
  41. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
  42. nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
  43. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  44. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
  45. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
  46. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
  47. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  48. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
  49. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
  50. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
  51. nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
  54. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
  55. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
  56. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
  57. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
  58. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
  59. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
  60. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
  61. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
  62. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
  64. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
  66. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
  67. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
  195. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
  196. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
  197. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
  198. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
  199. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
  200. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
  201. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  202. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  203. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  204. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  205. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
  206. nexaai/mlx_backend/vlm/interface.py +21 -4
  207. nexaai/mlx_backend/vlm/main.py +6 -2
  208. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  209. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  210. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  211. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  212. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  213. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  214. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  215. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  216. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  217. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  218. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
  219. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +221 -21
  220. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
  221. {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,312 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import mlx.core as mx
6
+ import mlx.nn as nn
7
+
8
+ from .config import VisionConfig
9
+
10
+
11
def check_array_shape(arr):
    """Heuristically decide whether a 5-D weight is already in the target layout.

    The shape is unpacked as ``(B, out_channels, kH, KW, t)``; these names are
    carried over from the original code and are only labels for the five axes.

    Args:
        arr: Any object exposing a ``shape`` sequence.

    Returns:
        ``True`` when the last axis equals 3, or when the second axis is at
        least as large as the third and fourth and those two are equal.
        ``False`` otherwise, including for any shape that is not 5-D.
    """
    shape = arr.shape

    # The unpacking below needs exactly five axes. The previous guard also let
    # 4-D shapes through, which then failed with ValueError on unpacking.
    if len(shape) != 5:
        return False

    _, out_channels, kH, KW, t = shape

    # A trailing axis of size 3 is accepted immediately.
    if t == 3:
        return True

    # Otherwise require the second axis to dominate two equal-sized axes.
    return bool((out_channels >= kH) and (out_channels >= KW) and (kH == KW))
28
+
29
+
30
def rotate_half(x):
    """Swap the two halves of the last axis, negating the half moved to the front."""
    midpoint = x.shape[-1] // 2
    leading, trailing = x[..., :midpoint], x[..., midpoint:]
    return mx.concatenate([-trailing, leading], axis=-1)
35
+
36
+
37
def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
    """Apply rotary position embedding to ``tensor`` using angle table ``freqs``.

    The cosine and sine tables each get a new axis at position 1, are tiled
    twice along their last axis, and get a leading axis before being combined
    with ``tensor`` and ``rotate_half(tensor)``. The result is cast back to the
    dtype ``tensor`` had on entry.
    """
    input_dtype = tensor.dtype

    def _broadcastable(table):
        # unsqueeze(1) -> repeat(1, 1, 2) -> [None, ...]
        table = mx.expand_dims(table, axis=1)
        table = mx.tile(table, (1, 1, 2))
        return mx.expand_dims(table, axis=0)

    cos = _broadcastable(mx.cos(freqs))
    sin = _broadcastable(mx.sin(freqs))

    rotated = (tensor * cos) + (rotate_half(tensor) * sin)
    return rotated.astype(input_dtype)
53
+
54
+
55
class VisionRotaryEmbedding(nn.Module):
    """Produce the rotary angle table ``outer(positions, inverse_frequencies)``."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        """Store the rotary dimension ``dim`` and the frequency base ``theta``."""
        super().__init__()
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> mx.array:
        """Return the angle table for positions ``0 .. seqlen - 1``.

        Args:
            seqlen: Number of positions. May be a plain ``int`` (as annotated)
                or an object with ``tolist()`` such as a scalar array. The
                previous code called ``.tolist()`` unconditionally and so
                failed on a plain ``int``.

        Returns:
            ``mx.outer(seq, inv_freq)``, one row per position and one column
            per even index in ``range(0, dim, 2)``.
        """
        inv_freq = 1.0 / (
            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )
        # Keep the original conversion path for array inputs; pass ints through.
        length = seqlen.tolist() if hasattr(seqlen, "tolist") else seqlen
        seq = mx.arange(length, dtype=inv_freq.dtype)
        return mx.outer(seq, inv_freq)
68
+
69
+
70
class PatchEmbed(nn.Module):
    """Project flattened patches to ``embed_dim`` with a strided, bias-free Conv3d."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        """Record the patch geometry and build the projection convolution."""
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # Kernel and stride are identical, so the convolution windows do not overlap.
        patch_extent = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=patch_extent,
            stride=patch_extent,
            bias=False,
        )

    def __call__(self, hidden_states: mx.array) -> mx.array:
        """Reshape the input into patches, project them, and flatten to ``(-1, embed_dim)``."""
        patch_shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # Move the channel axis (1) to the last position before the convolution.
        patches = hidden_states.reshape(*patch_shape).moveaxis(1, 4)
        projected = self.proj(patches)
        return projected.reshape(-1, self.embed_dim)
105
+
106
+
107
class PatchMerger(nn.Module):
    """Layer-normalise features, regroup them into wider rows, and apply a two-layer MLP."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        """Build the norm and the ``Linear -> GELU -> Linear`` stack.

        The merged width is ``context_dim * spatial_merge_size**2`` and the
        final linear layer maps it to ``dim``.
        """
        super().__init__()
        merged_width = context_dim * (spatial_merge_size**2)
        self.hidden_size = merged_width
        self.ln_q = nn.LayerNorm(context_dim, eps=1e-6)
        self.mlp = [
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        ]

    def __call__(self, x: mx.array) -> mx.array:
        """Normalise ``x``, reshape to ``(-1, hidden_size)``, and run the MLP stages in order."""
        merged = self.ln_q(x).reshape(-1, self.hidden_size)
        for stage in self.mlp:
            merged = stage(merged)
        return merged
123
+
124
+
125
class Attention(nn.Module):
    """Multi-head self-attention over a packed sequence, restricted to segments.

    Segment boundaries are given by ``cu_seqlens`` (cumulative offsets). A
    token may only attend to tokens inside its own ``[start, end)`` segment.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """Create the fused QKV projection and the output projection."""
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def __call__(
        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
    ) -> mx.array:
        """Run attention on ``x`` (first axis is the packed sequence).

        Args:
            x: Input whose first axis has length ``seq_length``.
            cu_seqlens: Cumulative segment boundaries; consecutive entries
                delimit one attention segment each.
            rotary_pos_emb: Rotary angle table. When ``None`` (the declared
                default) the rotary step is skipped instead of failing.

        Returns:
            The projected attention output reshaped to ``(seq_length, -1)``.
        """
        seq_length = x.shape[0]
        qkv = (
            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
        )
        q, k, v = mx.split(qkv, 3)

        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
            k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]

        # Additive mask: start fully blocked (most negative finite value) and
        # open each segment's diagonal block with zeros. The previous version
        # started from ones and indexed ``[start:end, start:end]`` on a 3-D
        # array, which addressed the size-1 leading axis and therefore never
        # blocked attention across segments.
        attention_mask = mx.full(
            (1, seq_length, seq_length), mx.finfo(x.dtype).min, dtype=x.dtype
        )

        for i in range(1, len(cu_seqlens)):
            start = int(cu_seqlens[i - 1])
            end = int(cu_seqlens[i])
            attention_mask[..., start:end, start:end] = 0

        # Swap axes 1 and 2 so the sequence axis sits next to the feature axis.
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        output = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=attention_mask
        )
        output = output.transpose(0, 2, 1, 3)
        output = output.reshape(seq_length, -1)
        return self.proj(output)
162
+
163
+
164
class MLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> fast-approximate GELU -> ``fc2``."""

    def __init__(self, dim, hidden_dim):
        """Create the activation and the two linear layers (``dim -> hidden_dim -> dim``)."""
        super().__init__()
        self.activation_fn = nn.GELU(approx="fast")
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply the expand, activate, project sequence to ``x``."""
        expanded = self.fc1(x)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
175
+
176
+
177
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual add."""

    def __init__(self, config: VisionConfig) -> None:
        """Build the two layer norms, the attention module, and the MLP from ``config``."""
        super().__init__()
        width = config.embed_dim
        self.norm1 = nn.LayerNorm(width, eps=1e-6)
        self.norm2 = nn.LayerNorm(width, eps=1e-6)

        self.attn = Attention(dim=width, num_heads=config.num_heads)
        self.mlp = MLP(dim=width, hidden_dim=int(width * config.mlp_ratio))

    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
        """Return ``hidden_states`` after the attention and MLP residual branches."""
        attended = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + attended

        fed_forward = self.mlp(self.norm2(hidden_states))
        return hidden_states + fed_forward
195
+
196
+
197
class VisionModel(nn.Module):
    """Vision tower: patch embedding, a stack of ``Qwen2VLVisionBlock`` layers, then a ``PatchMerger``."""

    def __init__(self, config: VisionConfig) -> None:
        """Build all sub-modules from ``config``.

        Raises:
            ValueError: If ``config.model_type`` is not ``"qwen2_vl"``.
        """
        super().__init__()
        self.config = config
        self.model_type = config.model_type
        if self.model_type != "qwen2_vl":
            raise ValueError(f"Unsupported model type: {self.model_type}")
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        # The rotary table is built for half of the per-head width.
        head_dim = config.embed_dim // config.num_heads
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
        self.merger = PatchMerger(dim=config.hidden_size, context_dim=config.embed_dim)

    def rot_pos_emb(self, grid_thw):
        """Build the rotary angle table for every patch described by ``grid_thw``.

        Each ``(t, h, w)`` entry yields row and column indices for an ``h x w``
        grid. The indices are reshaped and transposed so that positions inside
        the same ``spatial_merge_size x spatial_merge_size`` window end up
        adjacent, then tiled ``t`` times. The resulting index pairs select rows
        from the table produced by ``self.rotary_pos_emb``.
        """
        pos_ids = []

        for t, h, w in grid_thw:
            h, w = int(h), int(w)  # Ensure h and w are integers
            # Row index of every cell in the h x w grid.
            hpos_ids = mx.expand_dims(mx.arange(h), 1)
            hpos_ids = mx.repeat(hpos_ids, w, axis=1)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            # Group cells by merge window before flattening.
            hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
            hpos_ids = hpos_ids.flatten()

            # Column index of every cell, regrouped the same way.
            wpos_ids = mx.expand_dims(mx.arange(w), 0)
            wpos_ids = mx.repeat(wpos_ids, h, axis=0)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
            wpos_ids = wpos_ids.flatten()

            # One (row, column) pair per cell, repeated for each of the t steps.
            stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
            pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))

        pos_ids = mx.concatenate(pos_ids, axis=0)
        # The table only needs as many rows as the largest h or w seen.
        max_grid_size = mx.max(grid_thw[:, 1:])
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)

        # Gather the table rows for each (row, column) pair.
        rotary_pos_emb_full = rotary_pos_emb_full[pos_ids]

        # Flatten the pair axis so each patch has a single angle vector.
        return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1)

    def __call__(
        self,
        hidden_states: mx.array,
        grid_thw: mx.array,
        output_hidden_states: Optional[bool] = None,
    ) -> mx.array:
        """Encode patches and return the merged features.

        Args:
            hidden_states: Raw patch input passed to ``self.patch_embed``.
            grid_thw: One ``(t, h, w)`` row per item; used for rotary
                positions and for the attention segment boundaries.
            output_hidden_states: When truthy, per-layer states are collected
                in a local tuple. NOTE(review): that tuple is never returned,
                so this flag has no visible effect on the result.

        Returns:
            The output of ``self.merger`` applied to the final hidden states.
        """

        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # Assuming grid_thw has shape (batch_size, 3)
        batch_size = grid_thw.shape[0]

        # Calculate cu_seqlens for each item in the batch
        cu_seqlens = []
        for i in range(batch_size):
            # h * w patches per step, repeated once for each of the t steps.
            seq_len = grid_thw[i, 1] * grid_thw[i, 2]
            cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0]))

        # Concatenate the cu_seqlens for all items in the batch
        cu_seqlens = mx.concatenate(cu_seqlens)

        # Turn per-segment lengths into cumulative offsets with a leading 0.
        cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
        cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)

        encoder_states = (hidden_states,) if output_hidden_states else None

        for blk in self.blocks:
            hidden_states = blk(
                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            )
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

        return self.merger(hidden_states)

    def sanitize(self, weights):
        """Return a copy of ``weights`` adapted for loading into this module.

        Keys containing ``"position_ids"`` are dropped. Keys containing
        ``"patch_embed.proj.weight"`` are kept as-is when ``check_array_shape``
        accepts them and otherwise have axis 1 moved to the last position.
        All other entries are copied unchanged.
        """
        sanitized_weights = {}
        for k, v in weights.items():
            if "position_ids" in k:
                # Remove unused position_ids
                continue
            elif "patch_embed.proj.weight" in k:
                # The projection is an nn.Conv3d, so this weight has five axes.
                # When the layout check fails, axis 1 is moved to the end:
                # (0, 1, 2, 3, 4) -> (0, 2, 3, 4, 1).
                # NOTE(review): presumably this converts a channels-first
                # (PyTorch-style) weight to MLX's channels-last layout; confirm.
                if check_array_shape(v):
                    sanitized_weights[k] = v
                else:
                    sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
            else:
                sanitized_weights[k] = v

        return sanitized_weights
@@ -0,0 +1,117 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Any, Optional
4
+
5
+ import mlx.core as mx
6
+ from mlx.utils import tree_map
7
+
8
+ from .cache import QuantizedKVCache
9
+
10
+
11
@dataclass
class BaseModelArgs:
    """Base for argument dataclasses that can be built from a loosely-typed dict."""

    @classmethod
    def from_dict(cls, params):
        """Construct ``cls`` from ``params``, ignoring keys the constructor does not accept."""
        accepted = inspect.signature(cls).parameters
        kwargs = {}
        for key, value in params.items():
            if key in accepted:
                kwargs[key] = value
        return cls(**kwargs)
16
+
17
+
18
def create_causal_mask(
    N: int,
    offset: int = 0,
    window_size: Optional[int] = None,
    lengths: Optional[mx.array] = None,
):
    """Build a boolean causal mask for ``N`` query positions following ``offset`` earlier ones.

    Query positions span ``offset .. offset + N - 1`` and key positions span
    ``0 .. offset + N - 1``. An entry is true when the query position is at
    least the key position. ``window_size`` additionally requires the query to
    be no more than ``window_size`` ahead of the key, and ``lengths``
    additionally requires the key position to be below the given length.
    """
    key_pos = mx.arange(offset + N)
    # With no offset the query range equals the key range, so reuse it.
    query_pos = mx.arange(offset, offset + N) if offset else key_pos
    query_pos = query_pos[:, None]
    key_pos = key_pos[None]

    allowed = query_pos >= key_pos
    if window_size is not None:
        within_window = query_pos <= key_pos + window_size
        allowed = allowed & within_window
    if lengths is not None:
        limits = lengths[:, None, None, None]
        allowed = allowed & (key_pos < limits)
    return allowed
35
+
36
+
37
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, return_array: bool = False):
    """Choose the attention mask for input ``h`` given an optional layer-cache list.

    Returns ``None`` when ``h.shape[1]`` is at most 1, the string ``"causal"``
    when no explicit array is required, and otherwise the array built by
    ``create_causal_mask``. An explicit array is required when the caller asks
    for one, or when the first cache entry has a ``max_size`` and the clamped
    offset plus the new length exceeds it.
    """
    seq_len = h.shape[1]
    if seq_len <= 1:
        return None

    offset = 0
    window_size = None
    first = cache[0] if cache is not None else None
    if first is not None:
        offset = first.offset
        if hasattr(first, "max_size"):
            window_size = first.max_size
            offset = min(window_size, offset)
            return_array = return_array or offset + seq_len > window_size

    if not return_array:
        return "causal"
    return create_causal_mask(seq_len, offset, window_size=window_size)
56
+
57
+
58
def quantized_scaled_dot_product_attention(
    queries: mx.array,
    q_keys: tuple[mx.array, mx.array, mx.array],
    q_values: tuple[mx.array, mx.array, mx.array],
    scale: float,
    mask: Optional[mx.array],
    group_size: int = 64,
    bits: int = 8,
) -> mx.array:
    """Scaled dot-product attention over quantized keys and values.

    Args:
        queries: Unpacked as ``(B, n_q_heads, L, D)``.
        q_keys: Three-array tuple forwarded to ``mx.quantized_matmul``; the
            number of KV heads is read from ``q_keys[0].shape[-3]``.
        q_values: Three-array tuple forwarded to ``mx.quantized_matmul``.
        scale: Multiplier applied to the queries before the score matmul.
        mask: ``None``, a string (treated as a request for a causal mask), a
            boolean array (false entries are filled with the dtype minimum),
            or an additive array.
        group_size: Quantization group size passed to ``mx.quantized_matmul``.
        bits: Quantization bit width passed to ``mx.quantized_matmul``.

    Returns:
        The attention output; reshaped to ``(B, n_q_heads, L, D)`` when query
        heads were grouped over fewer KV heads.
    """
    B, n_q_heads, L, D = queries.shape
    n_kv_heads = q_keys[0].shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    # Out-of-place multiply: the previous ``queries *= scale`` was an in-place
    # operation on the array object the caller passed in.
    queries = queries * scale

    if n_repeats > 1:
        # Group query heads per KV head and give K/V a broadcastable axis.
        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)

    scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
    if mask is not None:
        if isinstance(mask, str):
            # Build the causal mask, aligning the queries with the last qL keys.
            qL, kL = scores.shape[-2:]
            q_indices = mx.arange(kL - qL, kL)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]
        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
        else:
            scores += mask
    scores = mx.softmax(scores, axis=-1, precise=True)
    out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)

    if n_repeats > 1:
        out = mx.reshape(out, (B, n_q_heads, L, D))

    return out
96
+
97
+
98
def scaled_dot_product_attention(
    queries,
    keys,
    values,
    cache,
    scale: float,
    mask: Optional[mx.array],
) -> mx.array:
    """Dispatch attention to the quantized or the fused implementation.

    A ``QuantizedKVCache`` routes to ``quantized_scaled_dot_product_attention``
    with the cache's ``group_size`` and ``bits``; anything else goes to
    ``mx.fast.scaled_dot_product_attention``.
    """
    if not isinstance(cache, QuantizedKVCache):
        return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)

    quant_options = {"group_size": cache.group_size, "bits": cache.bits}
    return quantized_scaled_dot_product_attention(
        queries,
        keys,
        values,
        scale=scale,
        mask=mask,
        **quant_options,
    )