nexaai 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nexaai might be problematic.

Files changed (200)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
  5. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
  6. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
  7. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
  8. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
  9. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
  10. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  11. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  12. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  13. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  14. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  15. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  16. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  17. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  18. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  19. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  21. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  22. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  23. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  24. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  33. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  34. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  35. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  36. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  37. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  39. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  40. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  41. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  43. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  44. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  45. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  46. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  47. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  48. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  49. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  54. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  55. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  56. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  57. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  58. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  59. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  61. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  62. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  63. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  64. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  65. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  195. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  196. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  197. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  198. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  199. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
  200. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py
@@ -1,386 +0,0 @@
- # Copyright © 2023-2024 Apple Inc.
-
- import json
- import os
- from typing import Optional
-
- import mlx.core as mx
- from huggingface_hub import hf_hub_download
- from mlx.utils import tree_unflatten
-
- from .clip import CLIPTextModel
- from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
- from .tokenizer import Tokenizer
- from .unet import UNetModel
- from .vae import Autoencoder
-
- _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
- _MODELS = {
-     # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
-     "stabilityai/sdxl-turbo": {
-         "unet_config": "unet/config.json",
-         "unet": "unet/diffusion_pytorch_model.safetensors",
-         "text_encoder_config": "text_encoder/config.json",
-         "text_encoder": "text_encoder/model.safetensors",
-         "text_encoder_2_config": "text_encoder_2/config.json",
-         "text_encoder_2": "text_encoder_2/model.safetensors",
-         "vae_config": "vae/config.json",
-         "vae": "vae/diffusion_pytorch_model.safetensors",
-         "diffusion_config": "scheduler/scheduler_config.json",
-         "tokenizer_vocab": "tokenizer/vocab.json",
-         "tokenizer_merges": "tokenizer/merges.txt",
-         "tokenizer_2_vocab": "tokenizer_2/vocab.json",
-         "tokenizer_2_merges": "tokenizer_2/merges.txt",
-     },
-     # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
-     "stabilityai/stable-diffusion-2-1-base": {
-         "unet_config": "unet/config.json",
-         "unet": "unet/diffusion_pytorch_model.safetensors",
-         "text_encoder_config": "text_encoder/config.json",
-         "text_encoder": "text_encoder/model.safetensors",
-         "vae_config": "vae/config.json",
-         "vae": "vae/diffusion_pytorch_model.safetensors",
-         "diffusion_config": "scheduler/scheduler_config.json",
-         "tokenizer_vocab": "tokenizer/vocab.json",
-         "tokenizer_merges": "tokenizer/merges.txt",
-     },
- }
-
-
- def map_unet_weights(key, value):
-     # Map up/downsampling
-     if "downsamplers" in key:
-         key = key.replace("downsamplers.0.conv", "downsample")
-     if "upsamplers" in key:
-         key = key.replace("upsamplers.0.conv", "upsample")
-
-     # Map the mid block
-     if "mid_block.resnets.0" in key:
-         key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-     if "mid_block.attentions.0" in key:
-         key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-     if "mid_block.resnets.1" in key:
-         key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-
-     # Map attention layers
-     if "to_k" in key:
-         key = key.replace("to_k", "key_proj")
-     if "to_out.0" in key:
-         key = key.replace("to_out.0", "out_proj")
-     if "to_q" in key:
-         key = key.replace("to_q", "query_proj")
-     if "to_v" in key:
-         key = key.replace("to_v", "value_proj")
-
-     # Map transformer ffn
-     if "ff.net.2" in key:
-         key = key.replace("ff.net.2", "linear3")
-     if "ff.net.0" in key:
-         k1 = key.replace("ff.net.0.proj", "linear1")
-         k2 = key.replace("ff.net.0.proj", "linear2")
-         v1, v2 = mx.split(value, 2)
-
-         return [(k1, v1), (k2, v2)]
-
-     if "conv_shortcut.weight" in key:
-         value = value.squeeze()
-
-     # Transform the weights from 1x1 convs to linear
-     if len(value.shape) == 4 and ("proj_in" in key or "proj_out" in key):
-         value = value.squeeze()
-
-     if len(value.shape) == 4:
-         value = value.transpose(0, 2, 3, 1)
-         value = value.reshape(-1).reshape(value.shape)
-
-     return [(key, value)]
-
-
- def map_clip_text_encoder_weights(key, value):
-     # Remove prefixes
-     if key.startswith("text_model."):
-         key = key[11:]
-     if key.startswith("embeddings."):
-         key = key[11:]
-     if key.startswith("encoder."):
-         key = key[8:]
-
-     # Map attention layers
-     if "self_attn." in key:
-         key = key.replace("self_attn.", "attention.")
-     if "q_proj." in key:
-         key = key.replace("q_proj.", "query_proj.")
-     if "k_proj." in key:
-         key = key.replace("k_proj.", "key_proj.")
-     if "v_proj." in key:
-         key = key.replace("v_proj.", "value_proj.")
-
-     # Map ffn layers
-     if "mlp.fc1" in key:
-         key = key.replace("mlp.fc1", "linear1")
-     if "mlp.fc2" in key:
-         key = key.replace("mlp.fc2", "linear2")
-
-     return [(key, value)]
-
-
- def map_vae_weights(key, value):
-     # Map up/downsampling
-     if "downsamplers" in key:
-         key = key.replace("downsamplers.0.conv", "downsample")
-     if "upsamplers" in key:
-         key = key.replace("upsamplers.0.conv", "upsample")
-
-     # Map attention layers
-     if "to_k" in key:
-         key = key.replace("to_k", "key_proj")
-     if "to_out.0" in key:
-         key = key.replace("to_out.0", "out_proj")
-     if "to_q" in key:
-         key = key.replace("to_q", "query_proj")
-     if "to_v" in key:
-         key = key.replace("to_v", "value_proj")
-
-     # Map the mid block
-     if "mid_block.resnets.0" in key:
-         key = key.replace("mid_block.resnets.0", "mid_blocks.0")
-     if "mid_block.attentions.0" in key:
-         key = key.replace("mid_block.attentions.0", "mid_blocks.1")
-     if "mid_block.resnets.1" in key:
-         key = key.replace("mid_block.resnets.1", "mid_blocks.2")
-
-     # Map the quant/post_quant layers
-     if "quant_conv" in key:
-         key = key.replace("quant_conv", "quant_proj")
-         value = value.squeeze()
-
-     # Map the conv_shortcut to linear
-     if "conv_shortcut.weight" in key:
-         value = value.squeeze()
-
-     if len(value.shape) == 4:
-         value = value.transpose(0, 2, 3, 1)
-         value = value.reshape(-1).reshape(value.shape)
-
-     return [(key, value)]
-
-
- def _flatten(params):
-     return [(k, v) for p in params for (k, v) in p]
-
-
- def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
-     dtype = mx.float16 if float16 else mx.float32
-     weights = mx.load(weight_file)
-     weights = _flatten([mapper(k, v.astype(dtype)) for k, v in weights.items()])
-     model.update(tree_unflatten(weights))
-
-
- def _check_key(key: str, part: str):
-     # Check if it's a local path
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # For local paths, we'll use a default model structure
-         return
-     if key not in _MODELS:
-         raise ValueError(
-             f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
-         )
-
- def _get_model_path(key: str, file_path: str):
-     """Get the full path for a model file, supporting both local and HuggingFace paths"""
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path
-         return os.path.join(key, file_path)
-     else:
-         # HuggingFace path
-         return hf_hub_download(key, file_path)
-
-
- def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
-     """Load the stable diffusion UNet from Hugging Face Hub."""
-     _check_key(key, "load_unet")
-
-     # Get the config path
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         unet_config = "unet/config.json"
-     else:
-         unet_config = _MODELS[key]["unet_config"]
-
-     with open(_get_model_path(key, unet_config)) as f:
-         config = json.load(f)
-
-     n_blocks = len(config["block_out_channels"])
-     model = UNetModel(
-         UNetConfig(
-             in_channels=config["in_channels"],
-             out_channels=config["out_channels"],
-             block_out_channels=config["block_out_channels"],
-             layers_per_block=[config["layers_per_block"]] * n_blocks,
-             transformer_layers_per_block=config.get(
-                 "transformer_layers_per_block", (1,) * 4
-             ),
-             num_attention_heads=(
-                 [config["attention_head_dim"]] * n_blocks
-                 if isinstance(config["attention_head_dim"], int)
-                 else config["attention_head_dim"]
-             ),
-             cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
-             norm_num_groups=config["norm_num_groups"],
-             down_block_types=config["down_block_types"],
-             up_block_types=config["up_block_types"][::-1],
-             addition_embed_type=config.get("addition_embed_type", None),
-             addition_time_embed_dim=config.get("addition_time_embed_dim", None),
-             projection_class_embeddings_input_dim=config.get(
-                 "projection_class_embeddings_input_dim", None
-             ),
-         )
-     )
-
-     # Download the weights and map them into the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         unet_weights = "unet/diffusion_pytorch_model.safetensors"
-     else:
-         unet_weights = _MODELS[key]["unet"]
-
-     weight_file = _get_model_path(key, unet_weights)
-     _load_safetensor_weights(map_unet_weights, model, weight_file, float16)
-
-     return model
-
-
- def load_text_encoder(
-     key: str = _DEFAULT_MODEL,
-     float16: bool = False,
-     model_key: str = "text_encoder",
-     config_key: Optional[str] = None,
- ):
-     """Load the stable diffusion text encoder from Hugging Face Hub."""
-     _check_key(key, "load_text_encoder")
-
-     config_key = config_key or (model_key + "_config")
-
-     # Download the config and create the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         text_encoder_config = f"{model_key}/config.json"
-     else:
-         text_encoder_config = _MODELS[key][config_key]
-
-     with open(_get_model_path(key, text_encoder_config)) as f:
-         config = json.load(f)
-
-     with_projection = "WithProjection" in config["architectures"][0]
-
-     model = CLIPTextModel(
-         CLIPTextModelConfig(
-             num_layers=config["num_hidden_layers"],
-             model_dims=config["hidden_size"],
-             num_heads=config["num_attention_heads"],
-             max_length=config["max_position_embeddings"],
-             vocab_size=config["vocab_size"],
-             projection_dim=config["projection_dim"] if with_projection else None,
-             hidden_act=config.get("hidden_act", "quick_gelu"),
-         )
-     )
-
-     # Download the weights and map them into the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         text_encoder_weights = f"{model_key}/model.safetensors"
-     else:
-         text_encoder_weights = _MODELS[key][model_key]
-
-     weight_file = _get_model_path(key, text_encoder_weights)
-     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
-
-     return model
-
-
- def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
-     """Load the stable diffusion autoencoder from Hugging Face Hub."""
-     _check_key(key, "load_autoencoder")
-
-     # Download the config and create the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         vae_config = "vae/config.json"
-     else:
-         vae_config = _MODELS[key]["vae_config"]
-
-     with open(_get_model_path(key, vae_config)) as f:
-         config = json.load(f)
-
-     model = Autoencoder(
-         AutoencoderConfig(
-             in_channels=config["in_channels"],
-             out_channels=config["out_channels"],
-             latent_channels_out=2 * config["latent_channels"],
-             latent_channels_in=config["latent_channels"],
-             block_out_channels=config["block_out_channels"],
-             layers_per_block=config["layers_per_block"],
-             norm_num_groups=config["norm_num_groups"],
-             scaling_factor=config.get("scaling_factor", 0.18215),
-         )
-     )
-
-     # Download the weights and map them into the model
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         vae_weights = "vae/diffusion_pytorch_model.safetensors"
-     else:
-         vae_weights = _MODELS[key]["vae"]
-
-     weight_file = _get_model_path(key, vae_weights)
-     _load_safetensor_weights(map_vae_weights, model, weight_file, float16)
-
-     return model
-
-
- def load_diffusion_config(key: str = _DEFAULT_MODEL):
-     """Load the stable diffusion config from Hugging Face Hub."""
-     _check_key(key, "load_diffusion_config")
-
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         diffusion_config = "scheduler/scheduler_config.json"
-     else:
-         diffusion_config = _MODELS[key]["diffusion_config"]
-
-     with open(_get_model_path(key, diffusion_config)) as f:
-         config = json.load(f)
-
-     return DiffusionConfig(
-         beta_start=config["beta_start"],
-         beta_end=config["beta_end"],
-         beta_schedule=config["beta_schedule"],
-         num_train_steps=config["num_train_timesteps"],
-     )
-
-
- def load_tokenizer(
-     key: str = _DEFAULT_MODEL,
-     vocab_key: str = "tokenizer_vocab",
-     merges_key: str = "tokenizer_merges",
- ):
-     _check_key(key, "load_tokenizer")
-
-     if os.path.exists(key) or '/' in key or '\\' in key:
-         # Local path - use SDXL Turbo structure
-         # For SDXL Turbo, we always use the main tokenizer files
-         vocab_file = _get_model_path(key, "tokenizer/vocab.json")
-         merges_file = _get_model_path(key, "tokenizer/merges.txt")
-     else:
-         vocab_file = _get_model_path(key, _MODELS[key][vocab_key])
-         merges_file = _get_model_path(key, _MODELS[key][merges_key])
-
-     with open(vocab_file, encoding="utf-8") as f:
-         vocab = json.load(f)
-
-     with open(merges_file, encoding="utf-8") as f:
-         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
-     bpe_merges = [tuple(m.split()) for m in bpe_merges]
-     bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
-
-     return Tokenizer(bpe_ranks, vocab)
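
The three map_*_weights functions above share one contract: a mapper takes a single (key, value) entry from a safetensors checkpoint and returns a list of remapped entries, so one input tensor can expand into several outputs (the ff.net.0.proj split), after which _flatten and tree_unflatten rebuild the parameter tree that model.update() expects. Below is a minimal, self-contained sketch of that contract with a toy mapper; the checkpoint keys and shapes are illustrative, not taken from the real UNet.

import mlx.core as mx
from mlx.utils import tree_unflatten

def toy_mapper(key, value):
    # One-to-one rename, as in map_unet_weights' attention handling.
    if "to_q" in key:
        return [(key.replace("to_q", "query_proj"), value)]
    # One-to-many split, as in the ff.net.0.proj -> linear1/linear2 case.
    if "ff.net.0.proj" in key:
        v1, v2 = mx.split(value, 2)
        return [(key.replace("ff.net.0.proj", "linear1"), v1),
                (key.replace("ff.net.0.proj", "linear2"), v2)]
    return [(key, value)]

checkpoint = {
    "blocks.0.to_q.weight": mx.zeros((4, 4)),
    "blocks.0.ff.net.0.proj.weight": mx.zeros((8, 4)),
}
# _flatten() in the removed file performs exactly this comprehension.
flat = [pair for k, v in checkpoint.items() for pair in toy_mapper(k, v)]
params = tree_unflatten(flat)  # nested structure, ready for model.update()
print([k for k, _ in flat])
# ['blocks.0.query_proj.weight', 'blocks.0.linear1.weight', 'blocks.0.linear2.weight']

Keeping the mappers in this shape is what lets _load_safetensor_weights stay identical across the UNet, CLIP, and VAE loaders.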
nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py
@@ -1,105 +0,0 @@
- # Copyright © 2023 Apple Inc.
-
- import mlx.core as mx
-
- from .config import DiffusionConfig
-
-
- def _linspace(a, b, num):
-     x = mx.arange(0, num) / (num - 1)
-     return (b - a) * x + a
-
-
- def _interp(y, x_new):
-     """Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
-     x_low = x_new.astype(mx.int32)
-     x_high = mx.minimum(x_low + 1, len(y) - 1)
-
-     y_low = y[x_low]
-     y_high = y[x_high]
-     delta_x = x_new - x_low
-     y_new = y_low * (1 - delta_x) + delta_x * y_high
-
-     return y_new
-
-
- class SimpleEulerSampler:
-     """A simple Euler integrator that can be used to sample from our diffusion models.
-
-     The method ``step()`` performs one Euler step from x_t to x_t_prev.
-     """
-
-     def __init__(self, config: DiffusionConfig):
-         # Compute the noise schedule
-         if config.beta_schedule == "linear":
-             betas = _linspace(
-                 config.beta_start, config.beta_end, config.num_train_steps
-             )
-         elif config.beta_schedule == "scaled_linear":
-             betas = _linspace(
-                 config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
-             ).square()
-         else:
-             raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
-
-         alphas = 1 - betas
-         alphas_cumprod = mx.cumprod(alphas)
-
-         self._sigmas = mx.concatenate(
-             [mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
-         )
-
-     @property
-     def max_time(self):
-         return len(self._sigmas) - 1
-
-     def sample_prior(self, shape, dtype=mx.float32, key=None):
-         noise = mx.random.normal(shape, key=key)
-         return (
-             noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()
-         ).astype(dtype)
-
-     def add_noise(self, x, t, key=None):
-         noise = mx.random.normal(x.shape, key=key)
-         s = self.sigmas(t)
-         return (x + noise * s) * (s.square() + 1).rsqrt()
-
-     def sigmas(self, t):
-         return _interp(self._sigmas, t)
-
-     def timesteps(self, num_steps: int, start_time=None, dtype=mx.float32):
-         start_time = start_time or (len(self._sigmas) - 1)
-         assert 0 < start_time <= (len(self._sigmas) - 1)
-         steps = _linspace(start_time, 0, num_steps + 1).astype(dtype)
-         return list(zip(steps, steps[1:]))
-
-     def step(self, eps_pred, x_t, t, t_prev):
-         sigma = self.sigmas(t).astype(eps_pred.dtype)
-         sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
-
-         dt = sigma_prev - sigma
-         x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
-
-         x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
-
-         return x_t_prev
-
-
- class SimpleEulerAncestralSampler(SimpleEulerSampler):
-     def step(self, eps_pred, x_t, t, t_prev):
-         sigma = self.sigmas(t).astype(eps_pred.dtype)
-         sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
-
-         sigma2 = sigma.square()
-         sigma_prev2 = sigma_prev.square()
-         sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
-         sigma_down = (sigma_prev2 - sigma_up**2).sqrt()
-
-         dt = sigma_down - sigma
-         x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt
-         noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype)
-         x_t_prev = x_t_prev + noise * sigma_up
-
-         x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()
-
-         return x_t_prev
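
SimpleEulerSampler keeps latents in the scaled parameterization x_t = (x + sigma * noise) / sqrt(sigma^2 + 1), so step() unscales, takes one Euler step of size dt = sigma_prev - sigma along the predicted noise, and rescales. A hedged, self-contained sketch of the denoising loop this update supports; eps_model is a stand-in for the real UNet predictor and the sigma values and shapes here are made up for illustration:

import mlx.core as mx

def euler_denoise(eps_model, sigmas, x_t):
    # sigmas: decreasing noise levels ending at 0.0, as produced by
    # SimpleEulerSampler.timesteps()/sigmas() in the removed file above.
    for sigma, sigma_prev in zip(sigmas, sigmas[1:]):
        eps = eps_model(x_t, sigma)                    # predicted noise
        x = (sigma**2 + 1) ** 0.5 * x_t                # undo the 1/sqrt(s^2+1) scaling
        x = x + eps * (sigma_prev - sigma)             # one Euler step
        x_t = x * (sigma_prev**2 + 1) ** -0.5          # rescale for the next step
    return x_t

# Toy run with a dummy predictor that always returns zeros.
sigmas = [14.6, 7.0, 2.0, 0.5, 0.0]
x_t = mx.random.normal((1, 64, 64, 4)) * sigmas[0] * (sigmas[0] ** 2 + 1) ** -0.5
out = euler_denoise(lambda x, s: mx.zeros_like(x), sigmas, x_t)
print(out.shape)  # (1, 64, 64, 4)

SimpleEulerAncestralSampler.step() is the same update with sigma_down replacing sigma_prev in the step size, plus fresh noise scaled by sigma_up injected afterwards.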
nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py
@@ -1,100 +0,0 @@
- # Copyright © 2023 Apple Inc.
-
- import regex
-
-
- class Tokenizer:
-     """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
-
-     def __init__(self, bpe_ranks, vocab):
-         self.bpe_ranks = bpe_ranks
-         self.vocab = vocab
-         self.pat = regex.compile(
-             r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
-             regex.IGNORECASE,
-         )
-
-         self._cache = {self.bos: self.bos, self.eos: self.eos}
-
-     @property
-     def bos(self):
-         return "<|startoftext|>"
-
-     @property
-     def bos_token(self):
-         return self.vocab[self.bos]
-
-     @property
-     def eos(self):
-         return "<|endoftext|>"
-
-     @property
-     def eos_token(self):
-         return self.vocab[self.eos]
-
-     def bpe(self, text):
-         if text in self._cache:
-             return self._cache[text]
-
-         unigrams = list(text[:-1]) + [text[-1] + "</w>"]
-         unique_bigrams = set(zip(unigrams, unigrams[1:]))
-
-         if not unique_bigrams:
-             return unigrams
-
-         # In every iteration try to merge the two most likely bigrams. If none
-         # was merged we are done.
-         #
-         # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
-         while unique_bigrams:
-             bigram = min(
-                 unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
-             )
-             if bigram not in self.bpe_ranks:
-                 break
-
-             new_unigrams = []
-             skip = False
-             for a, b in zip(unigrams, unigrams[1:]):
-                 if skip:
-                     skip = False
-                     continue
-
-                 if (a, b) == bigram:
-                     new_unigrams.append(a + b)
-                     skip = True
-
-                 else:
-                     new_unigrams.append(a)
-
-             if not skip:
-                 new_unigrams.append(b)
-
-             unigrams = new_unigrams
-             unique_bigrams = set(zip(unigrams, unigrams[1:]))
-
-         self._cache[text] = unigrams
-
-         return unigrams
-
-     def tokenize(self, text, prepend_bos=True, append_eos=True):
-         if isinstance(text, list):
-             return [self.tokenize(t, prepend_bos, append_eos) for t in text]
-
-         # Lower case cleanup and split according to self.pat. Hugging Face does
-         # a much more thorough job here but this should suffice for 95% of
-         # cases.
-         clean_text = regex.sub(r"\s+", " ", text.lower())
-         tokens = regex.findall(self.pat, clean_text)
-
-         # Split the tokens according to the byte-pair merge file
-         bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
-
-         # Map to token ids and return
-         tokens = [self.vocab[t] for t in bpe_tokens]
-         if prepend_bos:
-             tokens = [self.bos_token] + tokens
-         if append_eos:
-             tokens.append(self.eos_token)
-
-         return tokens
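
Tokenizer.bpe() is the standard lowest-rank-first BPE loop: split a word into characters with </w> appended to the last one, then repeatedly merge the adjacent pair with the smallest rank until no remaining pair appears in the merge table. A standalone worked example of the same loop with a toy merge table; the real ranks come from tokenizer/merges.txt via load_tokenizer above, and the two merges below are invented:

# Toy merge table: rank is the line order in merges.txt, lowest merges first.
bpe_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1}

def bpe(text):
    # Split into characters, marking the end of the word with </w>.
    unigrams = list(text[:-1]) + [text[-1] + "</w>"]
    pairs = set(zip(unigrams, unigrams[1:]))
    while pairs:
        best = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if best not in bpe_ranks:
            break  # no adjacent pair is mergeable
        merged, skip = [], False
        for a, b in zip(unigrams, unigrams[1:]):
            if skip:
                skip = False
                continue
            if (a, b) == best:
                merged.append(a + b)
                skip = True
            else:
                merged.append(a)
        if not skip:
            merged.append(unigrams[-1])  # keep the trailing element
        unigrams = merged
        pairs = set(zip(unigrams, unigrams[1:]))
    return unigrams

print(bpe("low"))  # ['l', 'o', 'w</w>'] -> ['lo', 'w</w>'] -> ['low</w>']

tokenize() then maps the merged subwords through vocab and brackets them with the <|startoftext|> and <|endoftext|> ids so the CLIP text encoder sees fixed sentinel tokens.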