nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of nexaai might be problematic.

Files changed (196)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  5. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
  6. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  7. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  8. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  10. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  11. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  12. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  13. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  14. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  15. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  16. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  17. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  18. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  19. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  21. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  22. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  23. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  24. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  33. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  34. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  35. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  36. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  37. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  39. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  40. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  41. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  43. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  44. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  45. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  46. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  47. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  48. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  49. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  54. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  55. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  56. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  57. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  58. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  59. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  61. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  62. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  64. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  195. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  196. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
@@ -1,490 +0,0 @@
- from typing import Optional
-
- import mlx.core as mx
- import mlx.nn as nn
- import numpy as np
-
- from ..base import (
-     LanguageModelOutput,
-     create_attention_mask,
-     scaled_dot_product_attention,
- )
- from ..cache import KVCache
- from .config import ModelConfig, TextConfig
-
-
- class Qwen2RotaryEmbedding:
-     def __init__(self, dim, max_position_embeddings=2048, base=10000):
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-
-         inv_freq = 1.0 / (
-             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
-         )
-         self.inv_freq = inv_freq
-
-         self._set_cos_sin_cache(seq_len=max_position_embeddings)
-
-     def _set_cos_sin_cache(self, seq_len):
-         self.max_seq_len_cached = seq_len
-         t = mx.arange(self.max_seq_len_cached).astype(mx.float32)
-
-         freqs = mx.outer(t, self.inv_freq)
-         emb = mx.concatenate((freqs, freqs), axis=-1)
-         self.cos_cached = mx.cos(emb)
-         self.sin_cached = mx.sin(emb)
-
-     def __call__(self, x, seq_len=None):
-
-         if seq_len > self.max_seq_len_cached:
-             self._set_cos_sin_cache(seq_len=seq_len)
-
-         return (
-             self.cos_cached[:seq_len].astype(x.dtype),
-             self.sin_cached[:seq_len].astype(x.dtype),
-         )
-
-
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return mx.concatenate([-x2, x1], axis=-1)
-
-
- def apply_multimodal_rotary_pos_emb(q, k, cos, sin, position_ids, mrope_section):
-     """
-     Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
-     Args:
-         q (mx.array): The query tensor.
-         k (mx.array): The key tensor.
-         cos (mx.array): The cosine part of the rotary embedding.
-         sin (mx.array): The sine part of the rotary embedding.
-         mrope_section (List[int]): Multimodal rope section for channel dimension of temporal, height and width.
-         unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
-     Returns:
-         tuple(mx.array): The rotated query and key tensors.
-     """
-
-     mrope_section = np.cumsum(mrope_section * 2)[:-1].tolist()
-     cos = cos[position_ids]
-     sin = sin[position_ids]
-
-     cos = mx.concatenate(
-         [m[i % 3] for i, m in enumerate(mx.split(cos, mrope_section, axis=-1))], axis=-1
-     )[
-         :, None, :, :
-     ]  # unsqueeze dim 1
-     sin = mx.concatenate(
-         [m[i % 3] for i, m in enumerate(mx.split(sin, mrope_section, axis=-1))], axis=-1
-     )[:, None, :, :]
-
-     # Apply rotary embedding
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-
-     return q_embed, k_embed
-
-
- class Attention(nn.Module):
-     def __init__(self, args: TextConfig):
-         super().__init__()
-
-         dim = args.hidden_size
-         self.n_heads = n_heads = args.num_attention_heads
-         assert args.num_key_value_heads is not None
-         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-
-         self.head_dim = head_dim = args.hidden_size // n_heads
-         self.scale = head_dim**-0.5
-
-         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
-         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
-         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
-         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
-
-         self.rope_scaling = args.rope_scaling
-
-         self.rotary_emb = Qwen2RotaryEmbedding(
-             head_dim,
-             max_position_embeddings=args.max_position_embeddings,
-             base=args.rope_theta,
-         )
-
-     def __call__(
-         self,
-         x: mx.array,
-         mask: Optional[mx.array] = None,
-         cache: Optional[KVCache] = None,
-         position_ids: Optional[mx.array] = None,
-     ) -> mx.array:
-         B, L, D = x.shape
-
-         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-         # Prepare the queries, keys and values for the attention computation
-         queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
-             0, 2, 1, 3
-         )
-         keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
-         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
-             0, 2, 1, 3
-         )
-
-         kv_seq_len = keys.shape[-2]
-
-         if position_ids is None:
-             kv_seq_len += cache.offset + 1
-             position_ids = mx.arange(cache.offset, cache.offset + L)
-             position_ids = mx.expand_dims(position_ids, axis=0)
-             position_ids = mx.tile(position_ids, (3, 1, 1))
-         else:
-             kv_seq_len += cache.offset + 1 if cache is not None else 0
-
-         cos, sin = self.rotary_emb(values, kv_seq_len)
-
-         if mask is not None and isinstance(mask, mx.array):
-             mask = mask[..., : keys.shape[-2]]
-         queries, keys = apply_multimodal_rotary_pos_emb(
-             queries, keys, cos, sin, position_ids, self.rope_scaling["mrope_section"]
-         )
-
-         if cache is not None:
-             keys, values = cache.update_and_fetch(keys, values)
-
-         output = scaled_dot_product_attention(
-             queries, keys, values, cache, scale=self.scale, mask=mask
-         )
-         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-         return self.o_proj(output)
-
-
- class MLP(nn.Module):
-     def __init__(self, dim, hidden_dim):
-         super().__init__()
-         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
-         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
-         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
-
-     def __call__(self, x) -> mx.array:
-         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
-
-
- class Qwen2VLDecoderLayer(nn.Module):
-     def __init__(self, args: TextConfig):
-         super().__init__()
-         self.num_attention_heads = args.num_attention_heads
-         self.hidden_size = args.hidden_size
-         self.self_attn = Attention(args)
-         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-         self.post_attention_layernorm = nn.RMSNorm(
-             args.hidden_size, eps=args.rms_norm_eps
-         )
-         self.args = args
-
-     def __call__(
-         self,
-         x: mx.array,
-         mask: Optional[mx.array] = None,
-         cache: Optional[KVCache] = None,
-         position_ids: Optional[mx.array] = None,
-     ) -> mx.array:
-         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
-         h = x + r
-         r = self.mlp(self.post_attention_layernorm(h))
-         out = h + r
-         return out
-
-
- class Qwen2Model(nn.Module):
-     def __init__(self, args: TextConfig):
-         super().__init__()
-         self.args = args
-         self.vocab_size = args.vocab_size
-         self.num_hidden_layers = args.num_hidden_layers
-         assert self.vocab_size > 0
-         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
-         self.layers = [
-             Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
-         ]
-         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-     def __call__(
-         self,
-         inputs: mx.array,
-         inputs_embeds: Optional[mx.array] = None,
-         mask: Optional[mx.array] = None,
-         cache=None,
-         position_ids: Optional[mx.array] = None,
-     ):
-         if inputs_embeds is None:
-             h = self.embed_tokens(inputs)
-         else:
-             h = inputs_embeds
-
-         if cache is None:
-             cache = [None] * len(self.layers)
-
-         if mask is None:
-             mask = create_attention_mask(h, cache)
-
-         for layer, c in zip(self.layers, cache):
-             h = layer(h, mask, c, position_ids)
-
-         return self.norm(h)
-
-
- class LanguageModel(nn.Module):
-     def __init__(self, args: TextConfig, config: ModelConfig):
-         super().__init__()
-         self.args = args
-         self.config = config
-         self.model_type = args.model_type
-         self.model = Qwen2Model(args)
-         self.rope_deltas = None
-
-         if not args.tie_word_embeddings:
-             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
-
-     def get_rope_index(
-         self,
-         input_ids: mx.array,
-         image_grid_thw: Optional[mx.array] = None,
-         video_grid_thw: Optional[mx.array] = None,
-         attention_mask: Optional[mx.array] = None,
-     ):
-         # Calculate RoPE index for image/video tokens
-         batch_size, seq_length = input_ids.shape
-         position_ids = mx.arange(seq_length, dtype=mx.int32)
-         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
-         spatial_merge_size = self.config.vision_config.spatial_merge_size
-         image_token_id = self.config.image_token_id
-         video_token_id = self.config.video_token_id
-         vision_start_token_id = self.config.vision_start_token_id
-         mrope_position_deltas = []
-         if input_ids is not None and (
-             image_grid_thw is not None or video_grid_thw is not None
-         ):
-             total_input_ids = input_ids
-             if attention_mask is None:
-                 attention_mask = mx.ones_like(input_ids)
-             position_ids = mx.ones(
-                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
-             )
-             image_index, video_index = 0, 0
-             for i, input_ids in enumerate(total_input_ids):
-                 input_ids = mx.where(
-                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
-                 )
-                 image_nums, video_nums = 0, 0
-                 vision_start_indices = mx.sum(
-                     mx.where(
-                         input_ids == vision_start_token_id,
-                         mx.arange(input_ids.shape[0]),
-                         mx.zeros_like(input_ids),
-                     )
-                 )
-                 vision_tokens = input_ids[vision_start_indices + 1]
-                 image_nums = (vision_tokens == image_token_id).sum().item()
-                 video_nums = (vision_tokens == video_token_id).sum().item()
-                 input_tokens = input_ids.tolist()
-                 llm_pos_ids_list: list = []
-                 st = 0
-                 remain_images, remain_videos = image_nums, video_nums
-                 for _ in range(image_nums + video_nums):
-                     if image_token_id in input_tokens and remain_images > 0:
-                         ed_image = input_tokens.index(image_token_id, st)
-                     else:
-                         ed_image = len(input_tokens) + 1
-                     if video_token_id in input_tokens and remain_videos > 0:
-                         ed_video = input_tokens.index(video_token_id, st)
-                     else:
-                         ed_video = len(input_tokens) + 1
-                     if ed_image < ed_video:
-                         t, h, w = (
-                             image_grid_thw[image_index][0],
-                             image_grid_thw[image_index][1],
-                             image_grid_thw[image_index][2],
-                         )
-                         image_index += 1
-                         remain_images -= 1
-                         ed = ed_image
-                     else:
-                         t, h, w = (
-                             video_grid_thw[video_index][0],
-                             video_grid_thw[video_index][1],
-                             video_grid_thw[video_index][2],
-                         )
-                         video_index += 1
-                         remain_videos -= 1
-                         ed = ed_video
-                     llm_grid_t, llm_grid_h, llm_grid_w = (
-                         t.item(),
-                         h.item() // spatial_merge_size,
-                         w.item() // spatial_merge_size,
-                     )
-                     text_len = ed - st
-                     st_idx = (
-                         llm_pos_ids_list[-1].max() + 1
-                         if len(llm_pos_ids_list) > 0
-                         else 0
-                     )
-                     index = mx.arange(text_len).reshape(1, text_len)
-                     index = mx.broadcast_to(index, (3, text_len))
-                     index = index + st_idx
-                     llm_pos_ids_list.append(index)
-                     t_index = mx.arange(llm_grid_t).reshape(
-                         llm_grid_t, 1
-                     )  # Equivalent to .view(-1, 1)
-                     t_index = mx.broadcast_to(
-                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
-                     )  # Equivalent to expand()
-                     t_index = t_index.flatten()  # Flattens to 1D
-
-                     h_index = mx.arange(llm_grid_h).reshape(
-                         1, llm_grid_h, 1
-                     )  # Equivalent to .view(1, -1)
-                     h_index = mx.broadcast_to(
-                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
-                     )  # Equivalent to expand()
-                     h_index = h_index.flatten()  # Flattens to 1D
-
-                     w_index = mx.arange(llm_grid_w).reshape(
-                         1, 1, llm_grid_w
-                     )  # Equivalent to .view(1, -1)
-                     w_index = mx.broadcast_to(
-                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
-                     )  # Equivalent to expand()
-                     w_index = w_index.flatten()  # Flattens to 1D
-
-                     llm_pos_ids_list.append(
-                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
-                     )
-                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-                 if st < len(input_tokens):
-                     st_idx = (
-                         llm_pos_ids_list[-1].max() + 1
-                         if len(llm_pos_ids_list) > 0
-                         else 0
-                     )
-                     text_len = len(input_tokens) - st
-
-                     t_index = mx.arange(text_len).reshape(
-                         1, text_len
-                     )  # Equivalent to .view(-1, 1)
-                     t_index = mx.broadcast_to(
-                         t_index, (3, text_len)
-                     )  # Equivalent to expand(3, -1)
-
-                     llm_pos_ids_list.append(t_index + st_idx)
-
-                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
-                 mask = mx.array(attention_mask[i] == 1)
-                 expanded_mask = mx.expand_dims(mask, axis=0)
-                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
-                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
-                 new_positions = mx.where(
-                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
-                 )
-                 updated_position_ids = mx.concatenate(
-                     [
-                         position_ids[:, :i, :],
-                         new_positions,
-                         position_ids[:, i + 1 :, :],
-                     ],
-                     axis=1,
-                 )
-                 position_ids = updated_position_ids
-                 mrope_position_deltas.append(
-                     llm_positions.max() + 1 - len(total_input_ids[i])
-                 )
-             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
-             return position_ids, mrope_position_deltas
-         else:
-             if attention_mask is not None:
-                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
-                 position_ids = mx.where(
-                     attention_mask == 0, mx.ones_like(position_ids), position_ids
-                 )
-                 position_ids = mx.expand_dims(position_ids[0], axis=0)
-                 position_ids = mx.tile(position_ids, (3, 1, 1))
-                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
-                     -1, keepdims=True
-                 )[0]
-                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
-             else:
-                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
-                 position_ids = mx.broadcast_to(
-                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
-                 )
-                 mrope_position_deltas = mx.zeros(
-                     [input_ids.shape[0], 1],
-                     dtype=input_ids.dtype,
-                 )
-             return position_ids, mrope_position_deltas
-
-     def __call__(
-         self,
-         inputs: mx.array,
-         inputs_embeds: Optional[mx.array] = None,
-         mask: Optional[mx.array] = None,
-         cache=None,
-         **kwargs,
-     ):
-
-         position_ids = kwargs.pop("position_ids", None)
-         pixel_values = kwargs.pop("pixel_values", None)
-         image_grid_thw = kwargs.pop("image_grid_thw", None)
-         video_grid_thw = kwargs.pop("video_grid_thw", None)
-         # reset rope_deltas when processing a new image/video
-         if pixel_values is not None:
-             self.rope_deltas = None
-
-         if position_ids is None and (mask is None or mask.ndim == 2):
-             # Calculate RoPE index once per generation in the pre-fill stage only
-             if (
-                 (cache is not None and cache[0] is not None and cache[0].offset == 0)
-                 or self.rope_deltas is None
-                 or cache is None
-             ):
-                 position_ids, rope_deltas = self.get_rope_index(
-                     inputs, image_grid_thw, video_grid_thw, mask
-                 )
-                 self.rope_deltas = rope_deltas
-             else:
-                 # Use the prev pre-calculated rope-deltas to get the correct position ids
-                 batch_size, seq_length = inputs.shape
-                 delta = cache[-1].offset + self.rope_deltas if cache is not None else 0
-                 delta = delta[None][None]
-                 position_ids = mx.arange(seq_length).reshape(1, seq_length)
-                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
-                 if cache is not None:
-                     # Repeat delta for each batch
-                     delta = mx.repeat(delta, batch_size // delta.shape[0], axis=0)
-                 position_ids = mx.add(position_ids, delta).reshape(position_ids.shape)
-                 position_ids = mx.broadcast_to(
-                     position_ids, (3, batch_size, seq_length)
-                 )
-
-         out = self.model(
-             inputs, cache=cache, inputs_embeds=inputs_embeds, position_ids=position_ids
-         )
-         if self.args.tie_word_embeddings:
-             out = self.model.embed_tokens.as_linear(out)
-         else:
-             out = self.lm_head(out)
-         return LanguageModelOutput(logits=out)
-
-     @property
-     def layers(self):
-         return self.model.layers
-
-     @property
-     def head_dim(self):
-         return self.args.hidden_size // self.args.num_attention_heads
-
-     @property
-     def n_kv_heads(self):
-         return self.args.num_key_value_heads
@@ -1,167 +0,0 @@
- import glob
- import inspect
- import json
- from dataclasses import dataclass
- from pathlib import Path
- from typing import List, Optional
-
- import mlx.core as mx
- import mlx.nn as nn
- import numpy as np
- from huggingface_hub import snapshot_download
-
- from .config import ModelConfig, TextConfig, VisionConfig
- from .language import LanguageModel
- from .vision import VisionModel
-
-
- class Model(nn.Module):
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-         self.config = config
-         self.vision_tower = VisionModel(config.vision_config)
-         self.language_model = LanguageModel(config.text_config, config)
-
-     def get_input_embeddings(
-         self,
-         input_ids: Optional[mx.array] = None,
-         pixel_values: Optional[mx.array] = None,
-         grid_thw: Optional[mx.array] = None,
-     ):
-
-         if pixel_values is None:
-             return self.language_model.model.embed_tokens(input_ids)
-
-         dtype = self.vision_tower.patch_embed.proj.weight.dtype
-         pixel_values = pixel_values.astype(dtype)
-
-         # Get the input embeddings from the language model
-         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-         # Get the ouptut hidden states from the vision model
-         hidden_states = self.vision_tower(
-             pixel_values, grid_thw, output_hidden_states=False
-         )
-
-         # Insert special image tokens in the input_ids
-         final_inputs_embeds = self.merge_input_ids_with_image_features(
-             self.config.image_token_id,
-             self.config.video_token_id,
-             hidden_states,
-             inputs_embeds,
-             input_ids,
-         )
-         return final_inputs_embeds
-
-     @staticmethod
-     def merge_input_ids_with_image_features(
-         image_token_id,
-         video_token_id,
-         image_features,
-         inputs_embeds,
-         input_ids,
-     ):
-         """Merge image features into input embeddings at image token positions.
-
-         Args:
-             image_features: Vision features from the vision tower [num_features, hidden_dim]
-             inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
-             input_ids: Input token IDs [batch_size, seq_len]
-
-         Returns:
-             Updated input embeddings with image features inserted
-         """
-
-         # Positions of <image> tokens in input_ids
-         image_positions = input_ids == image_token_id
-         if mx.sum(image_positions) == 0:
-             image_positions = input_ids == video_token_id
-
-         # Get dimensions
-         batch_size, seq_len = input_ids.shape
-
-         # Process each batch item
-         batch_outputs = []
-         feature_start_idx = 0
-
-         for batch_idx in range(batch_size):
-             # Get mask for this batch
-             image_mask = image_positions[batch_idx]
-             num_positions = mx.sum(image_mask).item()
-
-             if num_positions > 0:
-                 # Extract features for this batch
-                 batch_features = image_features[
-                     feature_start_idx : feature_start_idx + num_positions
-                 ]
-
-                 # Validate we have the right number of features
-                 if batch_features.shape[0] != num_positions:
-                     raise ValueError(
-                         f"Number of image token positions ({num_positions}) does not match "
-                         f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
-                     )
-
-                 # Create indices for gathering
-                 cumsum = mx.cumsum(image_mask.astype(mx.int32))
-                 feature_indices = mx.where(image_mask, cumsum - 1, 0)
-
-                 # Gather features
-                 gathered_features = batch_features[feature_indices]
-
-                 # Combine with original embeddings
-                 image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
-                 batch_output = mx.where(
-                     image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
-                 )
-
-                 feature_start_idx += num_positions
-             else:
-                 # No image tokens in this batch item
-                 batch_output = inputs_embeds[batch_idx]
-
-             batch_outputs.append(batch_output)
-
-         # Stack all batch outputs
-         return mx.stack(batch_outputs, axis=0)
-
-     @property
-     def layers(self):
-         return self.language_model.model.layers
-
-     def __call__(
-         self,
-         input_ids: mx.array,
-         pixel_values: Optional[mx.array] = None,
-         mask: Optional[mx.array] = None,
-         cache=None,
-         **kwargs,
-     ):
-
-         image_grid_thw = kwargs.pop("image_grid_thw", None)
-         video_grid_thw = kwargs.pop("video_grid_thw", None)
-         grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw
-         input_embddings = self.get_input_embeddings(input_ids, pixel_values, grid_thw)
-         kwargs = {
-             "pixel_values": pixel_values,
-             "image_grid_thw": image_grid_thw,
-             "video_grid_thw": video_grid_thw,
-             **kwargs,
-         }
-         logits = self.language_model(
-             input_ids, input_embddings, mask=mask, cache=cache, **kwargs
-         )
-         return logits
-
-     def sanitize(self, weights):
-         def transform_key(key):
-             if "vision_tower" not in key:
-                 key = key.replace("visual", "vision_tower")
-             if "language_model" not in key:
-                 if "model" in key:
-                     key = key.replace("model", "language_model.model")
-                 elif "lm_head" in key:
-                     key = key.replace("lm_head", "language_model.lm_head")
-             return key
-
-         return {transform_key(k): v for k, v in weights.items()}