nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.

Files changed (196)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  5. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
  6. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  7. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  8. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  10. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  11. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  12. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  13. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  14. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  15. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  16. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  17. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  18. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  19. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  21. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  22. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  23. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  24. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  33. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  34. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  35. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  36. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  37. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  39. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  40. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  41. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  43. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  44. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  45. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  46. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  47. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  48. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  49. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  54. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  55. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  56. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  57. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  58. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  59. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  61. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  62. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  64. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  195. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  196. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
@@ -1,522 +0,0 @@
1
- import inspect
2
- from dataclasses import dataclass
3
- from typing import List, Optional
4
-
5
- import mlx.core as mx
6
- import mlx.nn as nn
7
-
8
- from ..kernels import bicubic_interpolate
9
-
10
-
11
- @dataclass
12
- class VisionConfig:
13
- model_type: str = "moonvit"
14
- depth: int = 27
15
- embed_dim: int = 1152
16
- hidden_size: int = 1152
17
- num_heads: int = 16
18
- image_size: int = 384
19
- patch_size: int = 14
20
- vocab_size: int = 32000
21
- mlp_ratio: float = 4.0
22
- num_channels: int = 3
23
- layer_norm_eps: float = 1e-6
24
- intermediate_size: int = 4304
25
- init_pos_emb_height: int = 64
26
- init_pos_emb_width: int = 64
27
- spatial_patch_size: int = 14
28
- spatial_merge_size: int = 2
29
- temporal_patch_size: int = 2
30
- merge_kernel_size: list[int, int] = None
31
-
32
- def __post_init__(self):
33
- if self.merge_kernel_size is None:
34
- self.merge_kernel_size = (self.spatial_merge_size, self.spatial_merge_size)
35
-
36
- @classmethod
37
- def from_dict(cls, params):
38
- return cls(
39
- **{
40
- k: v
41
- for k, v in params.items()
42
- if k in inspect.signature(cls).parameters
43
- }
44
- )
45
-
46
-
47
- def check_array_shape(arr):
48
- shape = arr.shape
49
-
50
- # Check if the shape has 4 dimensions
51
- if len(shape) != 4:
52
- return False
53
-
54
- out_channels, kH, KW, _ = shape
55
-
56
- # Check if out_channels is the largest, and kH and KW are the same
57
- if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
58
- return True
59
- else:
60
- return False
61
-
62
-
63
- def rotate_half(x):
64
- """Rotates half the hidden dims of the input."""
65
- x1 = x[..., : x.shape[-1] // 2]
66
- x2 = x[..., x.shape[-1] // 2 :]
67
- return mx.concatenate([-x2, x1], axis=-1)
68
-
69
-
70
- def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
71
- orig_dtype = tensor.dtype
72
-
73
- cos = mx.cos(freqs)
74
- sin = mx.sin(freqs)
75
-
76
- cos = mx.expand_dims(cos, axis=1) # Equivalent to unsqueeze(1)
77
- cos = mx.tile(cos, (1, 1, 2)) # Equivalent to repeat(1, 1, 2)
78
- cos = mx.expand_dims(cos, axis=0) # Equivalent to [None, ...]
79
-
80
- sin = mx.expand_dims(sin, axis=1) # Equivalent to unsqueeze(1)
81
- sin = mx.tile(sin, (1, 1, 2)) # Equivalent to repeat(1, 1, 2)
82
- sin = mx.expand_dims(sin, axis=0) # Equivalent to [None, ...]
83
-
84
- output = (tensor * cos) + (rotate_half(tensor) * sin)
85
- return output.astype(orig_dtype)
86
-
87
-
88
- class VisionRotaryEmbedding(nn.Module):
89
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
90
- super().__init__()
91
- self.dim = dim
92
- self.theta = theta
93
-
94
- def __call__(self, seqlen: int) -> mx.array:
95
- inv_freq = 1.0 / (
96
- self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
97
- )
98
- seq = mx.arange(seqlen.tolist(), dtype=inv_freq.dtype)
99
- freqs = mx.outer(seq, inv_freq)
100
- return freqs
101
-
102
-
103
- class Learnable2DInterpPosEmb(nn.Module):
104
- def __init__(
105
- self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
106
- ) -> None:
107
- super().__init__()
108
- self.height = height
109
- self.width = width
110
- self.interpolation_mode = interpolation_mode
111
- self.weight = mx.ones((height, width, dim))
112
-
113
- def __call__(self, x: mx.array, grid_hws: mx.array) -> mx.array:
114
- pos_embs = []
115
- for shape in grid_hws.tolist():
116
- if shape == self.weight.shape[:-1]:
117
- pos_embs.append(self.weight.flatten(end_axis=1))
118
- else:
119
- result = (
120
- bicubic_interpolate(
121
- mx.expand_dims(self.weight.transpose(2, 0, 1), axis=0),
122
- size=shape,
123
- )
124
- .squeeze(0)
125
- .transpose(1, 2, 0)
126
- .flatten(end_axis=1)
127
- )
128
-
129
- pos_embs.append(result)
130
-
131
- out = x + mx.concatenate(pos_embs).astype(x.dtype)
132
- return out
133
-
134
-
135
- class PatchEmbed(nn.Module):
136
- def __init__(
137
- self,
138
- patch_size: int = 14,
139
- num_channels: int = 3,
140
- embed_dim: int = 1152,
141
- init_pos_emb_height: int = 64,
142
- ) -> None:
143
- super().__init__()
144
- self.patch_size = patch_size
145
- self.num_channels = num_channels
146
- self.embed_dim = embed_dim
147
- self.init_pos_emb_height = init_pos_emb_height
148
-
149
- self.proj = nn.Conv2d(
150
- num_channels,
151
- embed_dim,
152
- kernel_size=patch_size,
153
- stride=patch_size,
154
- bias=True,
155
- )
156
- self.pos_emb = Learnable2DInterpPosEmb(
157
- height=init_pos_emb_height, width=init_pos_emb_height, dim=embed_dim
158
- )
159
-
160
- def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> mx.array:
161
- hidden_states = self.proj(hidden_states).swapaxes(1, 3)
162
- hidden_states = hidden_states.reshape(hidden_states.shape[0], -1)
163
- hidden_states = self.pos_emb(hidden_states, grid_thw)
164
- return hidden_states
165
-
166
-
167
- def _apply_rope_input_validation(x, freqs_cis):
168
- assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
169
- assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
170
- assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
171
- assert freqs_cis.dtype == mx.complex64, freqs_cis.dtype
172
-
173
-
174
- def view_as_complex(x):
175
- """
176
- Convert a tensor with shape (..., 2) to a complex tensor with shape (...).
177
- """
178
- # Get real and imaginary parts
179
- real, imag = x[..., 0], x[..., 1]
180
- # Create complex tensor
181
- return real + 1j * imag
182
-
183
-
184
- def view_as_real(x):
185
- """
186
- Convert a complex tensor with shape (...) to a real tensor with shape (..., 2).
187
- """
188
- # Get real and imaginary parts
189
- real = mx.real(x)
190
- imag = mx.imag(x)
191
- # Combine into a tensor with last dimension 2
192
- return mx.stack([real, imag], axis=-1)
193
-
194
-
195
- def apply_rope(
196
- q: mx.array, k: mx.array, freqs_cis: mx.array
197
- ) -> tuple[mx.array, mx.array]:
198
- """
199
- Args: (The leading dimensions of all inputs should be the same)
200
- q: query, array of shape (..., num_heads, head_dim)
201
- k: key, array of shape (..., num_heads, head_dim)
202
- freqs_cis: array of shape (..., head_dim/2), dtype=mx.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
203
- Returns:
204
- xq_out, xk_out: arrays of shape (..., num_heads, head_dim)
205
- """
206
- _apply_rope_input_validation(q, freqs_cis)
207
- _apply_rope_input_validation(k, freqs_cis)
208
-
209
- freqs_cis = mx.expand_dims(freqs_cis, axis=-2) # ..., 1, head_dim/2
210
- # ..., num_heads, head_dim/2
211
- q_ = view_as_complex(q.astype(mx.float32).reshape(*q.shape[:-1], -1, 2))
212
- k_ = view_as_complex(k.astype(mx.float32).reshape(*k.shape[:-1], -1, 2))
213
- q_out = view_as_real(q_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
214
- k_out = view_as_real(k_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
215
- return q_out.astype(q.dtype), k_out.astype(k.dtype)
216
-
217
-
218
- class Attention(nn.Module):
219
- def __init__(self, dim: int, num_heads: int = 16) -> None:
220
- super().__init__()
221
- self.num_heads = num_heads
222
- self.head_dim = head_dim = dim // num_heads
223
- self.scale = head_dim**-0.5
224
- self.wqkv = nn.Linear(dim, dim * 3, bias=True)
225
- self.wo = nn.Linear(dim, dim, bias=True)
226
-
227
- def __call__(
228
- self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
229
- ) -> mx.array:
230
- seq_length = x.shape[0]
231
- qkv = self.wqkv(x)
232
-
233
- qkv_shape = qkv.shape[:-1] + (
234
- 3,
235
- self.num_heads,
236
- self.head_dim,
237
- )
238
- # xqkv: (batch_size, seqlen, 3, nheads, headdim)
239
- qkv = qkv.reshape(*qkv_shape)
240
-
241
- q, k, v = mx.split(qkv, 3, axis=1)
242
- q = q.squeeze(1)
243
- k = k.squeeze(1)
244
- v = v.squeeze(1)
245
-
246
- q, k = apply_rope(q, k, rotary_pos_emb)
247
-
248
- attention_mask = mx.zeros((1, seq_length, seq_length), dtype=x.dtype)
249
-
250
- # Create attention mask for each sequence in the batch
251
- for i in range(1, len(cu_seqlens)):
252
- start = int(cu_seqlens[i - 1])
253
- end = int(cu_seqlens[i])
254
- attention_mask[..., start:end, start:end] = 1
255
-
256
- q = q.transpose(1, 0, 2)
257
- k = k.transpose(1, 0, 2)
258
- v = v.transpose(1, 0, 2)
259
-
260
- attn_weight = q @ k.swapaxes(-2, -1) / mx.sqrt(q.shape[-1])
261
- attn_weight += attention_mask
262
- attn_weight = mx.softmax(attn_weight, axis=-1).astype(q.dtype)
263
-
264
- attn_output = attn_weight @ v
265
- attn_output = attn_output.transpose(1, 0, 2)
266
- attn_output = attn_output.reshape(seq_length, -1)
267
- return self.wo(attn_output)
268
-
269
-
270
- class MLP(nn.Module):
271
- def __init__(self, dim, hidden_dim):
272
- super().__init__()
273
- self.activation_fn = nn.GELU()
274
- self.fc0 = nn.Linear(dim, hidden_dim)
275
- self.fc1 = nn.Linear(hidden_dim, dim)
276
-
277
- def __call__(self, x: mx.array) -> mx.array:
278
- x = self.activation_fn(self.fc0(x))
279
- x = self.fc1(x)
280
- return x
281
-
282
-
283
- class Qwen2VLVisionBlock(nn.Module):
284
- def __init__(self, config: VisionConfig) -> None:
285
- super().__init__()
286
- self.norm0 = nn.LayerNorm(config.embed_dim, eps=1e-6)
287
- self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
288
-
289
- self.attn = Attention(dim=config.embed_dim, num_heads=config.num_heads)
290
- self.mlp = MLP(dim=config.embed_dim, hidden_dim=config.intermediate_size)
291
-
292
- def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
293
- hidden_states = hidden_states + self.attn(
294
- self.norm0(hidden_states),
295
- cu_seqlens=cu_seqlens,
296
- rotary_pos_emb=rotary_pos_emb,
297
- )
298
- hidden_states = hidden_states + self.mlp(self.norm1(hidden_states))
299
- return hidden_states
300
-
301
-
302
- class Rope2DPosEmb(nn.Module):
303
- """2D rotary position embedding with multi-resolution support.
304
-
305
- This class is intended to be used in the following way:
306
- 1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
307
- 2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
308
- 3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
309
- The rope is shared across all attention layers and all heads.
310
-
311
- Refs:
312
- - RoFormer: https://arxiv.org/abs/2104.09864
313
- - VisionLLaMA: https://arxiv.org/abs/2403.00522
314
- - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
315
-
316
- Args:
317
- dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
318
- max_height (int): the maximum height of the 2D grid
319
- max_width (int): the maximum width of the 2D grid
320
- theta_base (float): the base of the theta
321
- """
322
-
323
- def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
324
- super().__init__()
325
- self.dim = dim
326
- assert self.dim % 4 == 0, "dim must be divisible by 4"
327
- self.max_height = max_height
328
- self.max_width = max_width
329
- self.theta_base = theta_base
330
-
331
- self._freqs_cis = None
332
-
333
- def extra_repr(self):
334
- return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
335
-
336
- def _precompute_freqs_cis(self) -> mx.array:
337
- """Calculate the cis(freqs) for each position in the 2D grid.
338
-
339
- Return: complex array of shape (max_height, max_width, dim//2) and value:
340
- height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
341
- weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
342
- note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
343
- """
344
- N = self.max_height * self.max_width
345
- flat_pos = mx.arange(0, N, dtype=mx.float32)
346
- x_pos = flat_pos % self.max_width
347
- y_pos = flat_pos // self.max_width
348
- dim_range = mx.arange(0, self.dim, 4)[: (self.dim // 4)].astype(
349
- mx.float32
350
- ) # C/4
351
- freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
352
- x_freqs = mx.outer(x_pos, freqs) # N, C/4
353
- y_freqs = mx.outer(y_pos, freqs) # N, C/4
354
-
355
- # Create complex numbers using cos and sin
356
- x_cos = mx.cos(x_freqs)
357
- x_sin = mx.sin(x_freqs)
358
- y_cos = mx.cos(y_freqs)
359
- y_sin = mx.sin(y_freqs)
360
-
361
- # Create complex numbers
362
- x_cis = x_cos + 1j * x_sin # N, C/4
363
- y_cis = y_cos + 1j * y_sin # N, C/4
364
-
365
- # N, C/4, 2
366
- freqs_cis = mx.stack([x_cis, y_cis], axis=-1)
367
-
368
- # max_height, max_width, C/2
369
- freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
370
- return freqs_cis
371
-
372
- def get_freqs_cis(self, grid_hws: mx.array) -> mx.array:
373
- """
374
- Args:
375
- grid_hws (mx.array): grid height and width
376
-
377
- Returns:
378
- freqs_cis: array of shape (sum(t * height * width), dim//2)
379
- """
380
- if self._freqs_cis is None:
381
- self._freqs_cis = self._precompute_freqs_cis()
382
-
383
- shapes = grid_hws.tolist()
384
- assert all(
385
- 1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
386
- ), (
387
- shapes,
388
- self.max_height,
389
- self.max_width,
390
- )
391
-
392
- freqs_cis_list = []
393
- for h, w in shapes:
394
- # Get the slice of precomputed frequencies for this shape
395
- shape_freqs = self._freqs_cis[:h, :w]
396
- # Reshape to flatten the spatial dimensions
397
- shape_freqs = shape_freqs.reshape(-1, self.dim // 2)
398
- freqs_cis_list.append(shape_freqs)
399
-
400
- freqs_cis = mx.concatenate(freqs_cis_list, axis=0)
401
- return freqs_cis
402
-
403
-
404
- def patch_merger(
405
- x: mx.array,
406
- grid_hws: mx.array,
407
- merge_kernel_size: list[int, int] = (2, 2),
408
- ) -> List[mx.array]:
409
- d_model = x.shape[-1]
410
-
411
- outputs = []
412
- pre_sum = 0
413
- for x_shape in grid_hws.tolist():
414
- height, width = x_shape[0], x_shape[1]
415
- # Get the current sequence
416
- seq = x[pre_sum : pre_sum + height * width]
417
- # Reshape along self.merge_kernel_size and concat to the last dimension
418
- kernel_height, kernel_width = merge_kernel_size
419
- new_height, new_width = height // kernel_height, width // kernel_width
420
- reshaped_seq = seq.reshape(
421
- new_height, kernel_height, new_width, kernel_width, d_model
422
- )
423
- reshaped_seq = mx.transpose(reshaped_seq, (0, 2, 1, 3, 4))
424
- padded_seq = reshaped_seq.reshape(
425
- new_height * new_width, kernel_height * kernel_width, -1
426
- )
427
- outputs.append(padded_seq)
428
- pre_sum += height * width
429
-
430
- return outputs
431
-
432
-
433
- class VisionModel(nn.Module):
434
-
435
- def __init__(self, config: VisionConfig) -> None:
436
- super().__init__()
437
- self.config = config
438
- self.model_type = config.model_type
439
- if self.model_type not in ["qwen2_vl", "moonvit"]:
440
- raise ValueError(f"Unsupported model type: {self.model_type}")
441
- self.spatial_merge_size = config.spatial_merge_size
442
- self.merge_kernel_size = config.merge_kernel_size
443
-
444
- self.patch_embed = PatchEmbed(
445
- patch_size=config.patch_size,
446
- num_channels=config.num_channels,
447
- embed_dim=config.embed_dim,
448
- init_pos_emb_height=config.init_pos_emb_height,
449
- )
450
-
451
- head_dim = config.embed_dim // config.num_heads
452
- self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
453
-
454
- self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
455
- self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-6)
456
- self.rope_pos_emb = Rope2DPosEmb(head_dim, 512, 512)
457
-
458
- def __call__(
459
- self,
460
- hidden_states: mx.array,
461
- grid_thw: mx.array,
462
- output_hidden_states: Optional[bool] = None,
463
- ) -> mx.array:
464
-
465
- hidden_states = self.patch_embed(hidden_states, grid_thw)
466
- rotary_pos_emb = self.rope_pos_emb.get_freqs_cis(grid_thw)
467
-
468
- # Assuming grid_thw has shape (batch_size, 3)
469
- batch_size = grid_thw.shape[0]
470
-
471
- # Calculate cu_seqlens for each item in the batch
472
- lengths = mx.concatenate(
473
- (
474
- mx.zeros((1,), dtype=grid_thw.dtype),
475
- grid_thw[:, 0] * grid_thw[:, 1],
476
- )
477
- )
478
- cu_seqlens = mx.cumsum(lengths.astype(mx.int32), axis=0)
479
-
480
- encoder_states = (hidden_states,) if output_hidden_states else None
481
-
482
- for blk in self.blocks:
483
- hidden_states = blk(
484
- hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
485
- )
486
- if output_hidden_states:
487
- encoder_states = encoder_states + (hidden_states,)
488
-
489
- hidden_states = self.final_layernorm(hidden_states)
490
-
491
- hidden_states = patch_merger(
492
- hidden_states, grid_thw, merge_kernel_size=self.merge_kernel_size
493
- )
494
-
495
- return hidden_states
496
-
497
- def sanitize(self, weights):
498
- sanitized_weights = {}
499
- for k, v in weights.items():
500
- if "position_ids" in k:
501
- # Remove unused position_ids
502
- continue
503
- elif "patch_embed.proj.weight" in k:
504
- # PyTorch conv2d weight tensors have shape:
505
- # [out_channels, in_channels, kH, KW]
506
- # MLX conv2d expects the weight be of shape:
507
- # [out_channels, kH, KW, in_channels]
508
- if check_array_shape(v):
509
- sanitized_weights[k] = v
510
- else:
511
- sanitized_weights[k] = v.transpose(0, 2, 3, 1)
512
-
513
- elif "vision_tower.blocks" in k:
514
- if "attn" not in k and ("wqkv" in k or "wo" in k):
515
- new_key = k.replace("wqkv", "attn.wqkv").replace("wo", "attn.wo")
516
- sanitized_weights[new_key] = v
517
- else:
518
- sanitized_weights[k] = v
519
- else:
520
- sanitized_weights[k] = v
521
-
522
- return sanitized_weights
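
Note on the removed code: the `sanitize` method above converts PyTorch conv2d weights into MLX's channel-last kernel layout before loading. The following is a minimal, self-contained sketch of that conversion, assuming only that the `mlx` package is installed; `to_mlx_conv_weight` is an illustrative helper, not an API of nexaai:

import mlx.core as mx

def to_mlx_conv_weight(w: mx.array) -> mx.array:
    # PyTorch stores conv2d weights as [out_channels, in_channels, kH, kW];
    # MLX expects [out_channels, kH, kW, in_channels].
    return w.transpose(0, 2, 3, 1)

# Example: a 14x14 patch-embedding kernel with 3 input channels and
# 1152 output channels, matching the removed VisionConfig defaults.
w_torch = mx.ones((1152, 3, 14, 14))
w_mlx = to_mlx_conv_weight(w_torch)
print(w_mlx.shape)  # (1152, 14, 14, 3)

The `check_array_shape` guard in the removed file exists because a weight may already be in MLX layout (for example, after an earlier conversion pass), in which case the transpose must be skipped.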
@@ -1,8 +0,0 @@
1
- from .llama4 import (
2
- LanguageModel,
3
- Model,
4
- ModelConfig,
5
- TextConfig,
6
- VisionConfig,
7
- VisionModel,
8
- )
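
The removed kimi_vl vision module applies rotary position embeddings by viewing adjacent feature pairs as complex numbers and multiplying them by precomputed cis values (cis x = cos x + i sin x), as in `view_as_complex`, `view_as_real`, and `apply_rope` above. Below is a self-contained sketch of that rotation for a single position, assuming `mlx` is installed; `rope_1d` is a hypothetical one-position helper, not the package's `apply_rope`:

import math

import mlx.core as mx

def rope_1d(x: mx.array, theta: mx.array) -> mx.array:
    # Pair up features, view each pair as a complex number, rotate by
    # cis(theta) = cos(theta) + i*sin(theta), then split back into real pairs.
    x_pairs = x.reshape(*x.shape[:-1], -1, 2)
    x_complex = x_pairs[..., 0] + 1j * x_pairs[..., 1]
    cis = mx.cos(theta) + 1j * mx.sin(theta)
    rotated = x_complex * cis
    return mx.stack([mx.real(rotated), mx.imag(rotated)], axis=-1).reshape(*x.shape)

x = mx.arange(4, dtype=mx.float32)      # one token, head_dim = 4
theta = mx.array([0.0, math.pi / 2])    # one angle per feature pair
print(rope_1d(x, theta))                # [0, 1, -3, 2]: second pair rotated 90°

The 2D variant in the removed `Rope2DPosEmb` follows the same pattern, except that it interleaves one frequency series driven by the patch's x coordinate with another driven by its y coordinate.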