nexaai 1.0.19rc6-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic.

Files changed (224)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  5. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  6. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  7. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  8. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  9. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  10. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  11. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  12. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
  13. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
  14. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  15. nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
  16. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  17. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
  18. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
  19. nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
  20. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
  21. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
  23. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
  24. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
  25. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  26. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
  33. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
  34. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
  35. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
  36. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
  37. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
  38. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
  39. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  40. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
  41. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
  42. nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
  43. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  44. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
  45. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
  46. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
  47. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  48. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
  49. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
  50. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
  51. nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
  54. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
  55. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
  56. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
  57. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
  58. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
  59. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
  60. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
  61. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
  62. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
  64. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
  66. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
  67. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
  195. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
  196. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
  197. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
  198. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
  199. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
  200. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
  201. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  202. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  203. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  204. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  205. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
  206. nexaai/mlx_backend/vlm/interface.py +21 -4
  207. nexaai/mlx_backend/vlm/main.py +6 -2
  208. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  209. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  210. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  211. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  212. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  213. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  214. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  215. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  216. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
  217. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  218. nexaai/utils/manifest_utils.py +222 -15
  219. nexaai/utils/model_manager.py +83 -7
  220. nexaai/utils/model_types.py +2 -0
  221. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
  222. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +224 -24
  223. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
  224. {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .qwen2_5_vl import LanguageModel, Model, VisionModel
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py
@@ -0,0 +1,108 @@
+ import inspect
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Union
+
+
+ @dataclass
+ class VisionConfig:
+     model_type: str = "qwen2_5_vl"
+     depth: int = 32
+     hidden_size: int = 1280
+     intermediate_size: int = 3420
+     out_hidden_size: int = 1536
+     num_heads: int = 16
+     image_size: int = 384
+     patch_size: int = 14
+     vocab_size: int = 32000
+     mlp_ratio: float = 4.0
+     in_channels: int = 3
+     layer_norm_eps: float = 1e-6
+     spatial_patch_size: int = 14
+     spatial_merge_size: int = 2
+     tokens_per_second: int = 2
+     temporal_patch_size: int = 2
+     window_size: int = 112
+     patch_size: int = 14
+     fullatt_block_indexes: list[int] = field(default_factory=lambda: [7, 15, 23, 31])
+
+     @classmethod
+     def from_dict(cls, params):
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
+
+
+ @dataclass
+ class TextConfig:
+     model_type: str
+     hidden_size: int
+     num_hidden_layers: int
+     intermediate_size: int
+     num_attention_heads: int
+     rms_norm_eps: float
+     vocab_size: int
+     num_key_value_heads: Optional[int] = None
+     max_position_embeddings: Optional[int] = 128000
+     rope_theta: float = 1000000.0
+     rope_traditional: bool = False
+     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+     tie_word_embeddings: bool = True
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+         if self.rope_scaling:
+             required_keys = {"mrope_section", "type"}
+             if not all(key in self.rope_scaling for key in required_keys):
+                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+             if not self.rope_scaling["type"] in ["mrope", "default"]:
+                 raise ValueError(f"rope_scaling type must be 'mrope' or 'default'")
+
+     @classmethod
+     def from_dict(cls, params):
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
+
+
+ @dataclass
+ class ModelConfig:
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_id: int = 151655
+     video_token_id: int = 151656
+     vision_start_token_id: int = 151652
+     vision_end_token_id: int = 151653
+     vision_token_id: int = 151654
+     vision_feature_select_strategy: str = "default"
+     vision_feature_layer: int = -2
+     vocab_size: int = 32000
+     eos_token_id: Optional[List[int]] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         # Copy text config parameters from root level
+         excluded_keys = {"vision_config"}
+         params["text_config"] = dict(
+             filter(lambda x: x[0] not in excluded_keys, params.items())
+         )
+
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
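
All three `from_dict` helpers in the config file above use the same idiom: filter the incoming dictionary down to the keys that `inspect.signature` reports as constructor parameters, so extra keys in a checkpoint's `config.json` are silently dropped. A minimal standalone sketch of that idiom (the `DemoConfig` class and the sample dict are illustrative, not part of the package):

```python
import inspect
from dataclasses import dataclass


@dataclass
class DemoConfig:
    hidden_size: int = 1280
    num_heads: int = 16

    @classmethod
    def from_dict(cls, params):
        # Keep only the keys that match constructor parameters; drop the rest.
        known = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in known})


cfg = DemoConfig.from_dict({"hidden_size": 1536, "unexpected_key": "ignored"})
print(cfg)  # DemoConfig(hidden_size=1536, num_heads=16)
```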
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py
@@ -0,0 +1,490 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import ModelConfig, TextConfig
+
+
+ class Qwen2RotaryEmbedding:
+     def __init__(self, dim, max_position_embeddings=2048, base=10000):
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+         self.inv_freq = inv_freq
+
+         self._set_cos_sin_cache(seq_len=max_position_embeddings)
+
+     def _set_cos_sin_cache(self, seq_len):
+         self.max_seq_len_cached = seq_len
+         t = mx.arange(self.max_seq_len_cached).astype(mx.float32)
+
+         freqs = mx.outer(t, self.inv_freq)
+         emb = mx.concatenate((freqs, freqs), axis=-1)
+         self.cos_cached = mx.cos(emb)
+         self.sin_cached = mx.sin(emb)
+
+     def __call__(self, x, seq_len=None):
+
+         if seq_len > self.max_seq_len_cached:
+             self._set_cos_sin_cache(seq_len=seq_len)
+
+         return (
+             self.cos_cached[:seq_len].astype(x.dtype),
+             self.sin_cached[:seq_len].astype(x.dtype),
+         )
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section):
+     """
+     Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+     Args:
+         q (mx.array): The query tensor.
+         k (mx.array): The key tensor.
+         cos (mx.array): The cosine part of the rotary embedding.
+         sin (mx.array): The sine part of the rotary embedding.
+         mrope_section (List[int]): Multimodal rope section for channel dimension of temporal, height and width.
+         unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
+     Returns:
+         tuple(mx.array): The rotated query and key tensors.
+     """
+
+     mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
+     cos = cos[position_ids]
+     sin = sin[position_ids]
+
+     cos = mx.concatenate(
+         [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
+     )[
+         :, None, :, :
+     ]  # unsqueeze dim 1
+     sin = mx.concatenate(
+         [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
+     )[:, None, :, :]
+
+     # Apply rotary embedding
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+
+     return q_embed, k_embed
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         assert args.num_key_value_heads is not None
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+         self.head_dim = head_dim = args.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.rope_scaling = args.rope_scaling
+
+         self.rotary_emb = Qwen2RotaryEmbedding(
+             head_dim,
+             max_position_embeddings=args.max_position_embeddings,
+             base=args.rope_theta,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         kv_seq_len = keys.shape[-2]
+
+         if position_ids is None:
+             kv_seq_len += cache.offset + 1
+             position_ids = mx.arange(cache.offset, cache.offset + L)
+             position_ids = mx.expand_dims(position_ids, axis=0)
+             position_ids = mx.tile(position_ids, (3, 1, 1))
+         else:
+             kv_seq_len += cache.offset + 1 if cache is not None else 0
+
+         cos, sin = self.rotary_emb(values, kv_seq_len)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., : keys.shape[-2]]
+         queries, keys = apply_multimodal_rotary_pos_emb(
+             queries, keys, cos, sin, position_ids, self.rope_scaling["mrope_section"]
+         )
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Qwen2VLDecoderLayer(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.num_attention_heads = args.num_attention_heads
+         self.hidden_size = args.hidden_size
+         self.self_attn = Attention(args)
+         self.mlp = MLP(args.hidden_size, args.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             args.hidden_size, eps=args.rms_norm_eps
+         )
+         self.args = args
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Qwen2Model(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.num_hidden_layers = args.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [
+             Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c, position_ids)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ModelConfig):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = Qwen2Model(args)
+         self.rope_deltas = None
+
+         if not args.tie_word_embeddings:
+             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         # Calculate RoPE index for image/video tokens
+         batch_size, seq_length = input_ids.shape
+         position_ids = mx.arange(seq_length, dtype=mx.int32)
+         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         mrope_position_deltas = []
+         if input_ids is not None and (
+             image_grid_thw is not None or video_grid_thw is not None
+         ):
+             total_input_ids = input_ids
+             if attention_mask is None:
+                 attention_mask = mx.ones_like(input_ids)
+             position_ids = mx.ones(
+                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+             )
+             image_index, video_index = 0, 0
+             for i, input_ids in enumerate(total_input_ids):
+                 input_ids = mx.where(
+                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                 )
+                 image_nums, video_nums = 0, 0
+                 vision_start_indices = mx.sum(
+                     mx.where(
+                         input_ids == vision_start_token_id,
+                         mx.arange(input_ids.shape[0]),
+                         mx.zeros_like(input_ids),
+                     )
+                 )
+                 vision_tokens = input_ids[vision_start_indices + 1]
+                 image_nums = (vision_tokens == image_token_id).sum().item()
+                 video_nums = (vision_tokens == video_token_id).sum().item()
+                 input_tokens = input_ids.tolist()
+                 llm_pos_ids_list: list = []
+                 st = 0
+                 remain_images, remain_videos = image_nums, video_nums
+                 for _ in range(image_nums + video_nums):
+                     if image_token_id in input_tokens and remain_images > 0:
+                         ed_image = input_tokens.index(image_token_id, st)
+                     else:
+                         ed_image = len(input_tokens) + 1
+                     if video_token_id in input_tokens and remain_videos > 0:
+                         ed_video = input_tokens.index(video_token_id, st)
+                     else:
+                         ed_video = len(input_tokens) + 1
+                     if ed_image < ed_video:
+                         t, h, w = (
+                             image_grid_thw[image_index][0],
+                             image_grid_thw[image_index][1],
+                             image_grid_thw[image_index][2],
+                         )
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                     else:
+                         t, h, w = (
+                             video_grid_thw[video_index][0],
+                             video_grid_thw[video_index][1],
+                             video_grid_thw[video_index][2],
+                         )
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                     llm_grid_t, llm_grid_h, llm_grid_w = (
+                         t.item(),
+                         h.item() // spatial_merge_size,
+                         w.item() // spatial_merge_size,
+                     )
+                     text_len = ed - st
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     index = mx.arange(text_len).reshape(1, text_len)
+                     index = mx.broadcast_to(index, (3, text_len))
+                     index = index + st_idx
+                     llm_pos_ids_list.append(index)
+                     t_index = mx.arange(llm_grid_t).reshape(
+                         llm_grid_t, 1
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                     )  # Equivalent to expand()
+                     t_index = t_index.flatten()  # Flattens to 1D
+
+                     h_index = mx.arange(llm_grid_h).reshape(
+                         1, llm_grid_h, 1
+                     )  # Equivalent to .view(1, -1)
+                     h_index = mx.broadcast_to(
+                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     h_index = h_index.flatten()  # Flattens to 1D
+
+                     w_index = mx.arange(llm_grid_w).reshape(
+                         1, 1, llm_grid_w
+                     )  # Equivalent to .view(1, -1)
+                     w_index = mx.broadcast_to(
+                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )  # Equivalent to expand()
+                     w_index = w_index.flatten()  # Flattens to 1D
+
+                     llm_pos_ids_list.append(
+                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                     )
+                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                 if st < len(input_tokens):
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     text_len = len(input_tokens) - st
+
+                     t_index = mx.arange(text_len).reshape(
+                         1, text_len
+                     )  # Equivalent to .view(-1, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (3, text_len)
+                     )  # Equivalent to expand(3, -1)
+
+                     llm_pos_ids_list.append(t_index + st_idx)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                 mask = mx.array(attention_mask[i] == 1)
+                 expanded_mask = mx.expand_dims(mask, axis=0)
+                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                 new_positions = mx.where(
+                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                 )
+                 updated_position_ids = mx.concatenate(
+                     [
+                         position_ids[:, :i, :],
+                         new_positions,
+                         position_ids[:, i + 1 :, :],
+                     ],
+                     axis=1,
+                 )
+                 position_ids = updated_position_ids
+                 mrope_position_deltas.append(
+                     llm_positions.max() + 1 - len(total_input_ids[i])
+                 )
+             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+             return position_ids, mrope_position_deltas
+         else:
+             if attention_mask is not None:
+                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                 position_ids = mx.where(
+                     attention_mask == 0, mx.ones_like(position_ids), position_ids
+                 )
+                 position_ids = mx.expand_dims(position_ids[0], axis=0)
+                 position_ids = mx.tile(position_ids, (3, 1, 1))
+                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                     -1, keepdims=True
+                 )[0]
+                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+             else:
+                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                 )
+                 mrope_position_deltas = mx.zeros(
+                     [input_ids.shape[0], 1],
+                     dtype=input_ids.dtype,
+                 )
+             return position_ids, mrope_position_deltas
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         # reset rope_deltas when processing a new image/video
+         if pixel_values is not None:
+             self.rope_deltas = None
+
+         if position_ids is None and (mask is None or mask.ndim == 2):
+             # Calculate RoPE index once per generation in the pre-fill stage only
+             if (
+                 (cache is not None and cache[0] is not None and cache[0].offset == 0)
+                 or self.rope_deltas is None
+                 or cache is None
+             ):
+                 position_ids, rope_deltas = self.get_rope_index(
+                     inputs, image_grid_thw, video_grid_thw, mask
+                 )
+                 self.rope_deltas = rope_deltas
+             else:
+                 # Use the prev pre-calculated rope-deltas to get the correct position ids
+                 batch_size, seq_length = inputs.shape
+                 delta = cache[-1].offset + self.rope_deltas if cache is not None else 0
+                 delta = delta[None][None]
+                 position_ids = mx.arange(seq_length).reshape(1, seq_length)
+                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+                 if cache is not None:
+                     # Repeat delta for each batch
+                     delta = mx.repeat(delta, batch_size // delta.shape[0], axis=0)
+                 position_ids = mx.add(position_ids, delta).reshape(position_ids.shape)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, batch_size, seq_length)
+                 )
+
+         out = self.model(
+             inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
+         )
+         if self.args.tie_word_embeddings:
+             out = self.model.embed_tokens.as_linear(out)
+         else:
+             out = self.lm_head(out)
+         return LanguageModelOutput(logits=out)
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.hidden_size // self.args.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
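
For reference, the `rotate_half` trick used by `apply_multimodal_rotary_pos_emb` above is the standard half-split rotary formulation: the cached `cos`/`sin` are duplicated across both halves of the head dimension, so `q * cos + rotate_half(q) * sin` rotates each `(x1, x2)` channel pair by its per-frequency angle, while the cumulative sum over `mrope_section` decides which channels take their angle from the temporal, height, or width position index. A small NumPy sketch with toy values (illustrative only, not part of the package) checking the rotation identity and the section split:

```python
import numpy as np


def rotate_half(x):
    # Swap the two halves of the last axis with a sign flip: [x1, x2] -> [-x2, x1].
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)


# One position, head_dim = 4, so two rotation frequencies.
theta = np.array([0.3, 0.05])                # per-pair angles at this position
cos = np.concatenate([np.cos(theta)] * 2)    # duplicated layout, as in cos_cached
sin = np.concatenate([np.sin(theta)] * 2)

q = np.array([1.0, 2.0, 3.0, 4.0])           # half-split layout: (x1_a, x1_b, x2_a, x2_b)
q_rot = q * cos + rotate_half(q) * sin

# Direct 2x2 rotation of each (x1, x2) pair for comparison.
expected = np.concatenate([
    q[:2] * np.cos(theta) - q[2:] * np.sin(theta),
    q[2:] * np.cos(theta) + q[:2] * np.sin(theta),
])
assert np.allclose(q_rot, expected)

# mrope_section handling: a section like [16, 24, 24] covers half of a 128-dim head;
# doubling it and taking the cumulative sum gives the split points fed to mx.split,
# and the i % 3 indexing then cycles the chunks through the t/h/w position axes.
mrope_section = [16, 24, 24]
split_points = np.cumsum(mrope_section * 2)[:-1].tolist()
print(split_points)  # [16, 40, 64, 80, 104]
```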