nexaai 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of nexaai might be problematic.
Files changed (196)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  5. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
  6. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  7. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  8. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  10. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  11. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  12. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  13. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  14. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  15. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  16. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  17. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  18. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  19. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  21. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  22. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  23. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  24. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  33. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  34. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  35. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  36. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  37. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  39. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  40. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  41. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  43. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  44. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  45. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  46. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  47. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  48. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  49. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  54. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  55. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  56. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  57. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  58. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  59. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  61. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  62. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  64. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  195. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  196. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
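
Note: the +7 -196 change to RECORD, together with the list above, indicates that rc8 strips the bundled nexa_mlx py-lib source tree from the wheel, leaving only the compiled stub, the bridge dylib, and packaging metadata. Below is a minimal sketch of how one might confirm this locally by comparing the two wheels' RECORD files; it assumes both wheels have been downloaded into the current directory, and the filenames and the record_paths helper are illustrative, not part of nexaai.

# Hypothetical verification sketch: diff the RECORD file lists of the two wheels.
import zipfile

OLD = "nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl"  # assumed local path
NEW = "nexaai-1.0.19rc8-cp310-cp310-macosx_14_0_universal2.whl"  # assumed local path

def record_paths(wheel_path: str) -> set:
    """Return the set of file paths listed in the wheel's .dist-info/RECORD."""
    with zipfile.ZipFile(wheel_path) as whl:
        record_name = next(n for n in whl.namelist() if n.endswith(".dist-info/RECORD"))
        lines = whl.read(record_name).decode("utf-8").splitlines()
    # Each RECORD row is "path,hash,size"; keep only the path column.
    return {line.split(",")[0] for line in lines if line.strip()}

old_files = record_paths(OLD)
new_files = record_paths(NEW)

# Paths present in rc7 but gone in rc8 (expected: the nexa_mlx py-lib tree).
for path in sorted(old_files - new_files):
    print("removed:", path)
for path in sorted(new_files - old_files):
    print("added:  ", path)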
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py
@@ -1,560 +0,0 @@
- import inspect
- import math
- from dataclasses import dataclass
- from typing import Optional, Tuple
-
- import mlx.core as mx
- import mlx.nn as nn
-
- from ..base import pixel_shuffle
-
-
- @dataclass
- class VisionConfig:
-     model_type: str
-     hidden_size: int
-     image_size: int
-     initializer_range: float
-     intermediate_size: int
-     norm_eps: float
-     num_attention_heads: int
-     num_channels: int
-     num_hidden_layers: int
-     patch_size: int
-     pixel_shuffle_ratio: float
-     projector_dropout: float
-     projector_input_dim: int
-     projector_output_dim: int
-     rope_theta: float
-     vision_feature_layer: int
-     vision_feature_select_strategy: str
-     vision_output_dim: int
-
-     @classmethod
-     def from_dict(cls, params):
-         return cls(
-             **{
-                 k: v
-                 for k, v in params.items()
-                 if k in inspect.signature(cls).parameters
-             }
-         )
-
-
- def check_array_shape(arr):
-     shape = arr.shape
-
-     # Check if the shape has 4 dimensions
-     if len(shape) != 4:
-         return False
-
-     out_channels, kH, KW, _ = shape
-
-     # Check if out_channels is the largest, and kH and KW are the same
-     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
-         return True
-     else:
-         return False
-
-
- class Llama4MultiModalProjector(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.linear_1 = nn.Linear(
-             config.vision_config.vision_output_dim,
-             config.text_config.hidden_size,
-             bias=False,
-         )
-
-     def __call__(self, image_features):
-         hidden_states = self.linear_1(image_features)
-         return hidden_states
-
-
- class Llama4VisionPixelShuffleMLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
-         self.inner_dim = int(
-             config.projector_input_dim // (self.pixel_shuffle_ratio**2)
-         )
-         self.output_dim = config.projector_output_dim
-         self.mlp = Llama4VisionMLP(config, bias=False, is_projector=True)
-
-     def __call__(self, encoded_patches: mx.array) -> mx.array:
-         encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
-         return self.mlp(encoded_patches)
-
-
- # TODO there is a different RoPE for vision encoder, defined as below
- def reshape_for_broadcast(freqs_ci: mx.array, query: mx.array):
-     ndim = query.ndim
-     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]
-     return freqs_ci.reshape(*shape)
-
-
- def view_as_complex(x):
-     """
-     Convert a tensor with shape (..., 2) to a complex tensor with shape (...).
-
-     Args:
-         x: A real tensor with last dimension of size 2.
-
-     Returns:
-         A complex tensor with size one less than the input.
-     """
-     # Ensure the last dimension is size 2
-     assert x.shape[-1] == 2, f"Last dimension must be 2, got {x.shape[-1]}"
-
-     # Get real and imaginary parts
-     real, imag = x[..., 0], x[..., 1]
-
-     # Create complex tensor
-     return real + 1j * imag
-
-
- def view_as_real(x):
-     """
-     Convert a complex tensor with shape (...) to a real tensor with shape (..., 2).
-
-     Args:
-         x: A complex tensor.
-
-     Returns:
-         A real tensor with an extra dimension of size 2.
-     """
-     # Get real and imaginary parts
-     real = mx.real(x)
-     imag = mx.imag(x)
-
-     # Combine into a tensor with last dimension 2
-     return mx.stack([real, imag], axis=-1)
-
-
- def vision_apply_rotary_emb(
-     query: mx.array,
-     key: mx.array,
-     freqs_ci: mx.array,
- ) -> Tuple[mx.array, mx.array]:
-
-     query_ = view_as_complex(query.astype(mx.float32).reshape(*query.shape[:-1], -1, 2))
-     key_ = view_as_complex(key.astype(mx.float32).reshape(*key.shape[:-1], -1, 2))
-     freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_)
-     query_out = view_as_real(query_ * freqs_ci).flatten(3)
-     key_out = view_as_real(key_ * freqs_ci).flatten(3)
-     return query_out.astype(query.dtype), key_out.astype(key.dtype)
-
-
- class Llama4VisionAttention(nn.Module):
-     def __init__(self, config: VisionConfig):
-         super().__init__()
-         self.config = config
-         self.embed_dim = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = config.hidden_size // config.num_attention_heads
-         self.num_key_value_groups = 1
-         self.scale = self.head_dim**-0.5
-
-         self.q_proj = nn.Linear(
-             self.embed_dim, self.num_heads * self.head_dim, bias=True
-         )
-         self.k_proj = nn.Linear(
-             self.embed_dim, self.num_heads * self.head_dim, bias=True
-         )
-         self.v_proj = nn.Linear(
-             self.embed_dim, self.num_heads * self.head_dim, bias=True
-         )
-         self.o_proj = nn.Linear(
-             self.num_heads * self.head_dim, self.embed_dim, bias=True
-         )
-
-     def __call__(
-         self,
-         hidden_states: mx.array,
-         freqs_ci: mx.array,
-         mask: Optional[mx.array] = None,
-         cache: Optional[mx.array] = None,
-     ):
-         B, L, D = hidden_states.shape
-
-         query_states = self.q_proj(hidden_states).reshape(B, L, self.num_heads, -1)
-         key_states = self.k_proj(hidden_states).reshape(B, L, self.num_heads, -1)
-         value_states = self.v_proj(hidden_states).reshape(B, L, self.num_heads, -1)
-
-         query_states, key_states = vision_apply_rotary_emb(
-             query_states, key_states, freqs_ci=freqs_ci
-         )
-
-         query_states = query_states.transpose(0, 2, 1, 3)
-         key_states = key_states.transpose(0, 2, 1, 3)
-         value_states = value_states.transpose(0, 2, 1, 3)
-
-         attn_output = mx.fast.scaled_dot_product_attention(
-             query_states, key_states, value_states, scale=self.scale
-         )
-
-         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-         attn_output = self.o_proj(attn_output)
-         return attn_output
-
-
- class Llama4VisionMLP(nn.Module):
-     def __init__(self, config, bias=True, is_projector=False):
-         super().__init__()
-         self.config = config
-         self.activation_fn = nn.GELU(approx="fast") # ACT2FN[config.hidden_act]
-         self.is_projector = is_projector
-         self.hidden_size = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-
-         # Determine dimensions for first linear layer based on whether this is a projector
-         fc1_input_dim = self.intermediate_size if is_projector else self.hidden_size
-         fc1_output_dim = (
-             config.projector_input_dim if is_projector else self.intermediate_size
-         )
-
-         self.fc1 = nn.Linear(fc1_input_dim, fc1_output_dim, bias=bias)
-
-         # Determine dimensions for second linear layer
-         fc2_input_dim = (
-             config.projector_output_dim if is_projector else self.intermediate_size
-         )
-         fc2_output_dim = (
-             config.projector_output_dim if is_projector else self.hidden_size
-         )
-
-         self.fc2 = nn.Linear(fc2_input_dim, fc2_output_dim, bias=bias)
-
-         self.is_projector = is_projector
-
-     def __call__(self, hidden_states: mx.array) -> mx.array:
-         hidden_states = self.fc1(hidden_states)
-         hidden_states = self.activation_fn(hidden_states)
-
-         if self.is_projector:
-             return self.activation_fn(self.fc2(hidden_states))
-
-         return self.fc2(hidden_states)
-
-
- class Llama4VisionEncoderLayer(nn.Module):
-     def __init__(self, config: VisionConfig):
-         super().__init__()
-         self.hidden_size = config.hidden_size
-
-         self.self_attn = Llama4VisionAttention(config)
-         self.mlp = Llama4VisionMLP(config)
-
-         self.input_layernorm = nn.LayerNorm(config.hidden_size)
-         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
-
-     def __call__(
-         self,
-         hidden_state: mx.array,
-         freqs_ci: mx.array,
-         mask: Optional[mx.array] = None,
-     ):
-         # Self Attention
-         residual = hidden_state
-
-         hidden_state = self.input_layernorm(hidden_state)
-
-         hidden_state = self.self_attn(
-             hidden_state,
-             freqs_ci=freqs_ci,
-             mask=mask,
-         )
-         hidden_state = residual + hidden_state
-
-         # Feed forward
-         residual = hidden_state
-         hidden_state = self.post_attention_layernorm(hidden_state)
-         hidden_state = self.mlp(hidden_state)
-         hidden_state = residual + hidden_state
-         return hidden_state
-
-
- class Llama4VisionEncoder(nn.Module):
-     """
-     Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
-     [`Llama4VisionEncoderLayer`].
-
-     Args:
-         config: VisionConfig
-     """
-
-     def __init__(self, config: VisionConfig):
-         super().__init__()
-         self.config = config
-         self.layers = [
-             Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
-         ]
-         self.config = config
-
-     def __call__(
-         self,
-         hidden_states: mx.array,
-         freqs_ci: mx.array, # TODO move this to an attribute instead of keeping it around
-         mask: Optional[mx.array] = None,
-     ):
-
-         for i, encoder_layer in enumerate(self.layers):
-             hidden_states = encoder_layer(
-                 hidden_state=hidden_states,
-                 mask=mask,
-                 freqs_ci=freqs_ci,
-             )
-
-         return hidden_states
-
-
- class Llama4UnfoldConvolution(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         kernel_size = config.patch_size
-         if isinstance(kernel_size, int):
-             kernel_size = (kernel_size, kernel_size)
-         self.kernel_size = kernel_size
-         self.stride = config.patch_size
-         self.linear = nn.Linear(
-             config.num_channels * kernel_size[0] * kernel_size[1],
-             config.hidden_size,
-             bias=False,
-         )
-
-     def _pair(self, x):
-         """Convert input to a pair of values."""
-         if isinstance(x, (list, tuple)):
-             return tuple(x)
-         return (x, x)
-
-     def unfold(self, input_tensor):
-         """
-         Extract sliding local blocks from a batched input tensor (MLX implementation).
-
-         This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
-
-         Args:
-             input_tensor: Input tensor of shape (B, C, H, W)
-
-         Returns:
-             Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
-             where L is the number of blocks
-         """
-         # Convert to pairs
-         kernel_size = self._pair(self.kernel_size)
-         stride = self._pair(self.stride)
-         padding = (0, 0) # No padding in the original code
-         dilation = (1, 1) # Default dilation
-
-         # Input shape
-         batch_size, channels, height, width = input_tensor.shape
-
-         # Calculate output dimensions
-         height_out = (
-             height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
-         ) // stride[0] + 1
-         width_out = (
-             width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
-         ) // stride[1] + 1
-
-         # Initialize output arrays
-         blocks = []
-
-         # Extract blocks
-         for i in range(0, height - kernel_size[0] * dilation[0] + 1, stride[0]):
-             for j in range(0, width - kernel_size[1] * dilation[1] + 1, stride[1]):
-                 # Extract the block for all channels
-                 block = []
-                 for di in range(kernel_size[0]):
-                     for dj in range(kernel_size[1]):
-                         h_idx = i + di * dilation[0]
-                         w_idx = j + dj * dilation[1]
-                         # Get the block for all channels and add to our list
-                         block.append(input_tensor[:, :, h_idx, w_idx])
-
-                 # Stack the channel-blocks
-                 block = mx.stack(block, axis=1) # Shape: (B, k*k, C)
-                 block = mx.transpose(block, [0, 2, 1]) # Shape: (B, C, k*k)
-                 blocks.append(block)
-
-         # Stack all blocks together
-         result = mx.stack(blocks, axis=-1) # Shape: (B, C, k*k, L)
-
-         # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
-         result = mx.reshape(
-             result,
-             (
-                 batch_size,
-                 channels * kernel_size[0] * kernel_size[1],
-                 height_out * width_out,
-             ),
-         )
-
-         return result
-
-     def __call__(self, hidden_states: mx.array) -> mx.array:
-         hidden_states = self.unfold(hidden_states)
-         hidden_states = hidden_states.swapaxes(1, 2)
-         hidden_states = self.linear(hidden_states)
-         return hidden_states
-
-
- class Llama4VisionRotaryEmbedding:
-     def __init__(self, config):
-         super().__init__()
-         idx = config.image_size // config.patch_size
-         img_idx = mx.arange(idx**2, dtype=mx.int32).reshape(idx**2, 1)
-         img_idx = mx.concatenate([img_idx, img_idx[:1]], axis=0)
-         img_idx[-1, -1] = -2 # ID_CLS_TOKEN
-         frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x
-         frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y
-         freq_dim = config.hidden_size // config.num_attention_heads // 2
-         rope_freq = 1.0 / (
-             config.rope_theta
-             ** (
-                 mx.arange(0, freq_dim, 2, dtype=mx.float32)[: (freq_dim // 2)]
-                 / freq_dim
-             )
-         )
-
-         # Expand dimensions for frequencies_x and frequencies_y
-         freqs_x_expanded = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
-         freqs_y_expanded = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
-
-         def repeat_interleave(tensor, repeats, dim=-1):
-             # Get the shape
-             shape = list(tensor.shape)
-
-             # Reshape to add an extra dimension for repeating
-             tensor = mx.reshape(tensor, shape[:-1] + [shape[-1], 1])
-
-             # Repeat along the new dimension
-             tensor = mx.repeat(tensor, repeats, axis=-1)
-
-             # Reshape to flatten the last two dimensions
-             return mx.reshape(tensor, shape[:-1] + [shape[-1] * repeats])
-
-         # Apply interleaving
-         freqs_x = repeat_interleave(freqs_x_expanded, 2)
-         freqs_y = repeat_interleave(freqs_y_expanded, 2)
-         freqs = mx.concatenate([freqs_x, freqs_y], axis=-1).astype(mx.float32)[..., ::2]
-         # Replaced masked_fill with where
-         mask = img_idx.reshape(-1, 1, 1) < 0
-         freqs = mx.where(mask, mx.zeros_like(freqs), freqs)
-         freq_cis = mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
-         freq_cis = view_as_complex(freq_cis)
-         self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2
-
-     def __call__(self, hidden_states):
-         return self.freqs_ci
-
-
- class VisionModel(nn.Module):
-     def __init__(self, config: VisionConfig):
-         super().__init__()
-         self.image_size = config.image_size
-         self.patch_size = config.patch_size
-         self.hidden_size = config.hidden_size
-         self.num_channels = config.num_channels
-         self.model_type = config.model_type
-         if self.model_type not in ["llama4", "llama4_vision_model"]:
-             raise ValueError(f"Model type {self.model_type} not supported")
-
-         self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
-         self.scale = config.hidden_size**-0.5
-
-         self.class_embedding = self.scale * mx.random.normal((self.hidden_size,))
-         self.positional_embedding_vlm = self.scale * mx.random.normal(
-             (self.num_patches, self.hidden_size)
-         )
-
-         self.patch_embedding = Llama4UnfoldConvolution(config)
-
-         self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
-
-         # layer norms
-         self.layernorm_pre = nn.LayerNorm(self.hidden_size)
-         self.layernorm_post = nn.LayerNorm(self.hidden_size)
-
-         # encoders
-         self.model = Llama4VisionEncoder(config)
-         self.vision_adapter = Llama4VisionPixelShuffleMLP(config)
-
-     def get_input_embeddings(self):
-         """
-         This function is used to fetch the first embedding layer to activate grads on inputs.
-         """
-         return self.patch_embedding
-
-     def __call__(
-         self,
-         pixel_values: mx.array,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         capture_activations: Optional[bool] = True,
-     ):
-
-         batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
-         num_concurrent_media = 1
-         num_chunks = 1
-
-         hidden_state = self.patch_embedding(pixel_values)
-
-         _, num_patches, hidden_dim = hidden_state.shape
-
-         # Add cls token
-         hidden_state = hidden_state.reshape(
-             batch_size_times_num_tiles * num_concurrent_media * num_chunks,
-             num_patches,
-             hidden_dim,
-         )
-
-         class_embedding = mx.broadcast_to(
-             self.class_embedding, (hidden_state.shape[0], 1, hidden_state.shape[-1])
-         )
-         hidden_state = mx.concatenate([hidden_state, class_embedding], axis=1)
-         num_patches += 1
-
-         # Position embeddings
-         hidden_state = hidden_state.reshape(
-             batch_size_times_num_tiles * num_concurrent_media,
-             num_chunks,
-             num_patches,
-             hidden_dim,
-         )
-
-         positional_embedding = self.positional_embedding_vlm
-         hidden_state = hidden_state + positional_embedding
-
-         hidden_state = self.layernorm_pre(hidden_state)
-
-         hidden_state = hidden_state.reshape(batch_size_times_num_tiles, -1, hidden_dim)
-         freqs_ci = self.rotary_embedding(pixel_values)
-
-         hidden_state = self.model(
-             hidden_state,
-             mask=None,
-             freqs_ci=freqs_ci,
-         )
-
-         hidden_state = self.layernorm_post(hidden_state)
-
-         hidden_state = hidden_state[:, :-1, :]
-
-         # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
-         final_hidden_state = self.vision_adapter(hidden_state)
-
-         # Return only the final state
-         return final_hidden_state
-
-     def sanitize(self, weights):
-         sanitized_weights = {}
-         for k, v in weights.items():
-             if "position_ids" in k:
-                 # Remove unused position_ids
-                 continue
-             else:
-                 sanitized_weights[k] = v
-
-         return sanitized_weights
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py
@@ -1,8 +0,0 @@
- from .llava import (
-     LanguageModel,
-     Model,
-     ModelConfig,
-     TextConfig,
-     VisionConfig,
-     VisionModel,
- )