nexaai 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc8__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic. Click here for more details.

Files changed (196) hide show
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/binds/libnexa_bridge.dylib +0 -0
  4. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/METADATA +1 -1
  5. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/RECORD +7 -196
  6. nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
  7. nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
  8. nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
  9. nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
  10. nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
  11. nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
  12. nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
  13. nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
  14. nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
  15. nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
  16. nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
  17. nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
  18. nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
  19. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
  20. nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
  21. nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
  22. nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
  23. nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
  24. nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
  25. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
  26. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
  27. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
  28. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
  29. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
  30. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
  31. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
  32. nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
  33. nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
  34. nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
  35. nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
  36. nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
  37. nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
  38. nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
  39. nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
  40. nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
  41. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
  42. nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
  43. nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
  44. nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
  45. nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
  46. nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
  47. nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
  48. nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
  49. nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
  50. nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
  51. nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
  52. nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
  53. nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
  54. nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
  55. nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
  56. nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
  57. nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
  58. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
  59. nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
  60. nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
  61. nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
  62. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
  63. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
  64. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
  65. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
  66. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
  67. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
  68. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
  69. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
  70. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
  71. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
  72. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
  73. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
  74. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
  75. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
  76. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
  77. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
  78. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
  79. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
  80. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
  81. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
  82. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
  83. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
  84. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
  85. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
  86. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
  87. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
  88. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
  89. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
  90. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
  91. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
  92. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
  93. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
  94. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
  95. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
  96. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
  97. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
  98. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
  99. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
  100. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
  101. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
  102. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
  103. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
  104. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
  105. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
  106. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
  107. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
  108. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
  109. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
  110. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
  111. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
  112. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
  113. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
  114. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
  115. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
  116. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
  117. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
  118. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
  119. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
  120. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
  121. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
  122. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
  123. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
  124. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
  125. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
  126. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
  127. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
  128. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
  129. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
  130. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
  131. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
  132. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
  133. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
  134. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
  135. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
  136. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
  137. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
  138. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
  139. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
  140. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
  141. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
  142. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
  143. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
  144. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
  145. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
  146. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
  147. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
  148. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
  149. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
  150. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
  151. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
  152. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
  153. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
  154. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
  155. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
  156. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
  157. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
  158. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
  159. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
  160. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
  161. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
  162. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
  163. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
  164. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  165. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
  166. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
  167. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
  168. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
  169. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
  170. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
  171. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
  172. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
  173. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  174. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
  175. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
  176. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
  177. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
  178. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
  179. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
  180. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
  181. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
  182. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
  183. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
  184. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
  185. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
  186. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
  187. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
  188. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
  189. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
  190. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
  191. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
  192. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
  193. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
  194. nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
  195. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/WHEEL +0 -0
  196. {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc8.dist-info}/top_level.txt +0 -0
@@ -1,68 +0,0 @@
1
- # Copyright © 2024 Apple Inc.
2
-
3
- from mlx_lm import generate, load
4
-
5
-
6
def test_llm_generate_stream(model_path):
    """Run an interactive, multi-round chat against an MLX language model.

    The model and tokenizer are loaded once with ``load``.  Each round reads a
    line from stdin, appends it to the running conversation, renders the whole
    conversation through the tokenizer's chat template and passes the result
    to ``generate``.  The stripped return value of ``generate`` is stored as
    the assistant turn so later rounds keep the context.

    The loop ends when the user types ``quit``, ``exit`` or ``q``
    (case-insensitive) or when stdin is closed.

    Args:
        model_path: Value forwarded to ``load(path_or_hf_repo=...)``.
    """
    # Load the corresponding model and tokenizer
    model, tokenizer = load(path_or_hf_repo=model_path)

    # Conversation history to maintain context
    conversation = []

    # Specify the maximum number of tokens
    max_tokens = 1_000

    # Specify if tokens and timing information will be printed
    verbose = True

    print("Multi-round conversation started. Type 'quit' or 'exit' to end.")
    print("=" * 50)

    while True:
        # Get user input.  input() raises EOFError when stdin is closed
        # (Ctrl-D, or a pipe that ran dry); treat that as a request to quit
        # instead of crashing with a traceback.
        try:
            user_input = input("\nUser: ").strip()
        except EOFError:
            print()
            print("Goodbye!")
            break

        # Check for exit commands
        if user_input.lower() in ["quit", "exit", "q"]:
            print("Goodbye!")
            break

        # Ignore empty lines rather than sending an empty user turn
        if not user_input:
            continue

        # Add user input to conversation history
        conversation.append({"role": "user", "content": user_input})

        # Transform the conversation into the chat template
        prompt = tokenizer.apply_chat_template(
            conversation=conversation, add_generation_prompt=True
        )

        # Generate response
        print("Assistant: ", end="", flush=True)

        # Generate text, already handled KV cache
        response = generate(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            max_tokens=max_tokens,
            verbose=verbose,
        )

        # NOTE(review): the original comment claimed the response includes the
        # prompt; only surrounding whitespace is removed here - confirm
        # against the mlx_lm version in use.
        generated_text = response.strip()

        # Add assistant response to conversation history
        conversation.append({"role": "assistant", "content": generated_text})

        print()  # New line after response
61
-
62
-
63
if __name__ == "__main__":
    import argparse

    # Command-line entry point: the only option selects the model to chat with.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_path",
        type=str,
        default="mlx-community/Qwen3-1.7B-4bit-DWQ",
    )
    options = cli.parse_args()
    test_llm_generate_stream(options.model_path)
File without changes
@@ -1,174 +0,0 @@
1
- # Copyright © Nexa AI
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import sys
16
- import os
17
- import mlx.core as mx
18
- import mlx.nn as nn
19
- import numpy as np
20
- import time
21
-
22
- from transformers import AutoTokenizer
23
- from huggingface_hub import snapshot_download
24
- from .modeling.nexa_jina_rerank import Model, ModelArgs
25
-
26
-
27
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """Derive position ids from token ids, leaving padding at ``padding_idx``.

    Non-padding tokens are numbered by a running count along axis 1, shifted
    by ``past_key_values_length`` and then by ``padding_idx``; every padding
    token ends up with position ``padding_idx``.
    """
    # 1 where the token is real, 0 where it equals the padding id
    not_padding = (input_ids != padding_idx).astype(mx.int32)

    # Running count of real tokens along axis 1
    running_count = mx.cumsum(not_padding, axis=1)

    # Multiplying by the mask zeroes the padding slots before the final shift
    positions = (running_count + past_key_values_length) * not_padding
    return positions.astype(mx.int32) + padding_idx
32
-
33
-
34
def prepare_inputs(query, documents, tokenizer, max_length=1024):
    """Tokenize ``(query, document)`` pairs and build the model input arrays.

    Every document is paired with the same query, padded/truncated to
    ``max_length`` and converted to MLX arrays.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, position_ids)``.
        ``input_ids`` is int32, ``attention_mask`` is float32,
        ``token_type_ids`` is all zeros with the same shape as ``input_ids``
        and ``position_ids`` is computed with padding id 1.
    """
    encoded = tokenizer(
        [[query, document] for document in documents],
        padding="max_length",
        truncation=True,
        return_tensors="np",
        max_length=max_length,
    )

    input_ids = mx.array(encoded["input_ids"]).astype(mx.int32)
    attention_mask = mx.array(encoded["attention_mask"]).astype(mx.float32)
    batch_size, seqlen = input_ids.shape[0], input_ids.shape[1]

    # A single all-zero row, broadcast to one row per sentence pair
    zero_row = mx.expand_dims(mx.zeros(seqlen, dtype=mx.int32), axis=0)
    token_type_ids = mx.broadcast_to(zero_row, (batch_size, seqlen))

    # Position ids per sequence; 1 is passed as the padding token id
    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)

    return input_ids, attention_mask, token_type_ids, position_ids
60
-
61
-
62
def load_model(model_id):
    """Initialize and load the Jina V2 rerank model.

    Weights are cached in ``modelfiles/nexaml_jina_v2_rerank_mlx`` next to
    this file.  When that directory does not exist, the repository
    ``model_id`` is downloaded into it with ``snapshot_download``.

    Args:
        model_id: Hugging Face repository id to download from.

    Returns:
        Tuple ``(model, model_dir)``: the model in eval mode with weights
        loaded, and the directory holding the downloaded files.

    Raises:
        FileNotFoundError: If the directory contains no ``.safetensors`` file.
        Exception: Any error raised by ``snapshot_download`` is re-raised
            after the partially created directory has been removed.
    """
    import shutil

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = f"{curr_dir}/modelfiles/nexaml_jina_v2_rerank_mlx"

    # Download model if not exists
    if not os.path.exists(model_dir):
        print(f"Downloading model {model_id}...")

        os.makedirs(model_dir, exist_ok=True)

        try:
            snapshot_download(
                repo_id=model_id,
                allow_patterns=["*.safetensors", "config.json", "tokenizer*"],
                local_dir=model_dir,
                local_dir_use_symlinks=False
            )
            print("Model download completed!")
        except Exception as e:
            print(f"Failed to download model: {e}")
            print("Try: huggingface-cli login (if authentication required)")
            # The existence of model_dir is what suppresses the download, so a
            # directory left behind by a failed attempt would make every later
            # call skip the download and fail on missing weights.  Remove it
            # so the next call retries.
            shutil.rmtree(model_dir, ignore_errors=True)
            raise

    # Create model config
    config = ModelArgs()
    model = Model(config)

    # Load weights
    weight_file = os.path.join(model_dir, "model.safetensors")
    if not os.path.exists(weight_file):
        # Try alternative naming patterns.  os.listdir order is arbitrary, so
        # sort to make the chosen file deterministic.
        safetensors_files = sorted(
            f for f in os.listdir(model_dir) if f.endswith('.safetensors')
        )
        if safetensors_files:
            weight_file = os.path.join(model_dir, safetensors_files[0])
        else:
            raise FileNotFoundError(f"No .safetensors file found in {model_dir}")

    print(f"Loading weights from: {weight_file}")
    model.load_weights(weight_file, strict=True)
    model.eval()

    return model, model_dir
105
-
106
-
107
def load_tokenizer(model_path):
    """Return the tokenizer that ``AutoTokenizer`` loads from ``model_path``."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return tokenizer
110
-
111
-
112
def rerank_documents(model, tokenizer, query, documents, max_length=1024):
    """Rerank documents based on query relevance.

    Args:
        model: Object exposing ``nexa_forward(input_ids, attention_mask,
            token_type_ids, position_ids)``.
        tokenizer: Tokenizer forwarded to ``prepare_inputs``.
        query: Query string paired with every document.
        documents: Documents to score.
        max_length: Padding/truncation length forwarded to ``prepare_inputs``.

    Returns:
        Tuple ``(scores, scores_sigmoid, inference_time)``: the model output
        with its last axis squeezed, the same values passed through a
        sigmoid, and the forward-pass duration in milliseconds.
    """
    # Prepare inputs
    input_ids, attention_mask, token_type_ids, position_ids = prepare_inputs(
        query, documents, tokenizer, max_length
    )

    # Run inference
    start_time = time.time()
    scores = model.nexa_forward(input_ids, attention_mask, token_type_ids, position_ids)
    scores = mx.squeeze(scores, axis=-1)
    # MLX evaluates lazily: without forcing evaluation here the timer would
    # stop before the computation has actually run.
    mx.eval(scores)
    end_time = time.time()

    # Apply sigmoid to get probabilities
    scores_sigmoid = mx.sigmoid(scores)

    inference_time = (end_time - start_time) * 1000  # Convert to ms

    return scores, scores_sigmoid, inference_time
131
-
132
-
133
def main(model_id):
    """Load the rerank model, score a fixed example and print the results."""
    model, model_path = load_model(model_id)
    tokenizer = load_tokenizer(model_path)

    # Fixed demonstration input
    query = "What are the health benefits of green tea?"
    documents = [
        "Green tea is rich in antioxidants and may improve brain function.",
        "Coffee contains caffeine and can boost energy levels.",
        "Das Trinken von grünem Tee kann das Risiko für Herzkrankheiten senken.",
        "Black tea is another popular beverage with its own health benefits.",
    ]

    scores, scores_sigmoid, inference_time = rerank_documents(
        model, tokenizer, query, documents
    )

    # Report header
    banner = "=" * 70
    print(banner)
    print("Reranking Results:")
    print(banner)
    print(f"Query: {query}")
    print()

    # One entry per document, numbered from 1
    rows = zip(documents, scores.tolist(), scores_sigmoid.tolist())
    for number, (doc, score, prob) in enumerate(rows, start=1):
        print(f"Document {number}:")
        print(f" Text: {doc}")
        print(f" Score: {score:.4f}")
        print(f" Probability: {prob:.4f}")
        print()

    # inference_time is in milliseconds, hence the factor of 1000
    print(f"Inference time: {inference_time:.1f}ms")
    print(f"Throughput: {len(documents)/inference_time*1000:.1f} docs/s")
170
-
171
-
172
if __name__ == "__main__":
    # Repository id handed to main() when the module is run as a script.
    model_id = "nexaml/jina-v2-rerank-mlx"
    main(model_id=model_id)
@@ -1,287 +0,0 @@
1
- # Copyright © Nexa AI
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from __future__ import annotations
16
-
17
- import os
18
- import json
19
- import mlx.core as mx
20
- import mlx.nn as nn
21
- import numpy as np
22
- import time
23
- from pathlib import Path
24
- from typing import Any, List, Optional, Sequence
25
- from dataclasses import dataclass
26
- from abc import ABC, abstractmethod
27
-
28
- # Import necessary modules
29
- from transformers import AutoTokenizer
30
-
31
- # Import from ml.py for API alignment (assuming similar structure)
32
- try:
33
- from ml import (
34
- Reranker as BaseReranker,
35
- Path as PathType,
36
- )
37
- except ImportError:
38
- # Fallback to local definitions if ml.py not available
39
- PathType = Path
40
- BaseReranker = ABC
41
-
42
- # Import profiling module
43
- from profiling import ProfilingMixin, ProfilingData, StopReason
44
-
45
- # Import the model implementation
46
- from .modeling.nexa_jina_rerank import Model, ModelArgs
47
-
48
-
49
@dataclass
class RerankConfig:
    """Configuration for reranking.

    ``@dataclass`` generates ``__init__`` from the fields below, with the same
    parameter names, order and defaults as the previous hand-written
    constructor, so positional and keyword callers are unaffected.
    """

    # Number of documents scored per forward pass
    batch_size: int = 1
    # Whether scores are normalized before being returned
    normalize: bool = True
    # "softmax" | "min-max" | "none"
    normalize_method: str = "softmax"
65
-
66
-
67
class Reranker(BaseReranker, ProfilingMixin):
    """
    Reranker interface for MLX reranking models.
    API aligned with ml.py Reranker abstract base class.

    Typical use: construct, call ``load_model()``, then ``rerank()`` any
    number of times, and finally ``close()``/``destroy()``.
    """

    def __init__(
        self,
        model_path: PathType,
        tokenizer_path: PathType,
        device: Optional[str] = None,
    ) -> None:
        """Initialize the Reranker model.

        Args:
            model_path: Model directory, or a file inside it (the parent
                directory is used in that case).
            tokenizer_path: Stored on the instance; the tokenizer itself is
                loaded from the model directory by ``_load_tokenizer``.
            device: Device name; ``"cpu"`` is stored when ``None``.
        """
        # Initialize profiling mixin
        ProfilingMixin.__init__(self)

        # A weight file may be passed instead of its directory
        if os.path.isfile(model_path):
            model_path = os.path.dirname(model_path)

        # Call the ml.py parent constructor only when that base class was
        # actually imported.  The previous hasattr(super(), '__init__') check
        # was always true, so with the ABC fallback the three-argument call
        # reached a constructor that presumably does not accept them
        # (ProfilingMixin is not visible here - confirm).
        if BaseReranker is not ABC:
            super().__init__(model_path, tokenizer_path, device)

        # Store paths and device
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.device = device if device is not None else "cpu"

        # Model, tokenizer and config stay unset until load_model() runs
        self.model = None
        self.tokenizer = None
        self.config = None

    def destroy(self) -> None:
        """Destroy the model and free resources."""
        self.model = None
        self.tokenizer = None
        self.config = None

    def load_model(self, model_path: PathType, extra_data: Any = None) -> bool:
        """Load model and tokenizer from ``model_path``.

        Args:
            model_path: Model directory or a file inside it.  A falsy value
                keeps the path given to the constructor.
            extra_data: Accepted for API compatibility; not used here.

        Returns:
            ``True`` on success.  Any exception is printed and reported as
            ``False`` instead of being propagated.
        """
        try:
            # Use the provided model_path or fall back to instance path
            if model_path:
                # Apply same file-to-directory conversion as in __init__
                if os.path.isfile(model_path):
                    model_path = os.path.dirname(model_path)
                self.model_path = model_path

            # Load the model using internal implementation
            self.model = self._load_jina_model(self.model_path)
            self.tokenizer = self._load_tokenizer()

            return True
        except Exception as e:
            print(f"Failed to load model: {e}")
            return False

    def close(self) -> None:
        """Close the model."""
        self.destroy()

    def rerank(
        self,
        query: str,
        documents: Sequence[str],
        config: Optional[RerankConfig] = None,
        clear_cache: bool = True,
    ) -> mx.array:
        """Rerank documents given a query.

        Documents are scored in batches of ``config.batch_size``.  The raw
        scores of all batches are concatenated first and normalized once, so
        the normalized values are comparable across the whole document list
        regardless of the batch size.

        Args:
            query: Query string paired with every document.
            documents: Documents to score.
            config: Reranking options; a default ``RerankConfig`` when ``None``.
            clear_cache: Call ``mx.clear_cache()`` after every batch.

        Returns:
            One score per document, in input order.  An empty array when
            ``documents`` is empty.

        Raises:
            RuntimeError: If the model or tokenizer has not been loaded.
            ValueError: If ``config.batch_size`` is smaller than 1.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if config is None:
            config = RerankConfig()

        batch_size = config.batch_size
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Start profiling
        self._start_profiling()
        self._prompt_start()

        raw_scores = []

        # Process documents in batches
        for start in range(0, len(documents), batch_size):
            batch_docs = documents[start:start + batch_size]
            raw_scores.append(self._score_batch(query, batch_docs))

            if clear_cache:
                mx.clear_cache()

        # End prompt processing, start decode
        self._prompt_end()
        self._decode_start()

        if not raw_scores:
            # No documents: nothing to concatenate or normalize
            res = mx.zeros((0,), dtype=mx.float32)
        else:
            # Concatenate all batch scores into a single array
            res = raw_scores[0] if len(raw_scores) == 1 else mx.concatenate(raw_scores, axis=0)

            # Normalize over the full score vector.  Doing this per batch made
            # softmax return 1.0 for every document with batch_size=1.
            if config.normalize:
                res = self._normalize_scores(res, config.normalize_method)

        # End decode and profiling
        self._decode_end()
        self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
        self._end_profiling()

        return res

    def _load_jina_model(self, model_dir: str) -> Model:
        """Initialize and load the Jina V2 rerank model.

        Raises:
            ValueError: If ``model_dir`` does not exist.
            FileNotFoundError: If it contains no ``.safetensors`` file.
        """
        # Validate that model path exists
        if not os.path.exists(model_dir):
            raise ValueError(f"Model path does not exist: {model_dir}")

        # Store model directory for tokenizer loading
        self._model_dir = model_dir

        # Create model config
        config = ModelArgs()
        model = Model(config)

        # Load weights
        weight_file = os.path.join(model_dir, "model.safetensors")
        if not os.path.exists(weight_file):
            # Try alternative naming patterns.  os.listdir order is arbitrary,
            # so sort to make the chosen file deterministic.
            safetensors_files = sorted(
                f for f in os.listdir(model_dir) if f.endswith('.safetensors')
            )
            if safetensors_files:
                weight_file = os.path.join(model_dir, safetensors_files[0])
            else:
                raise FileNotFoundError(f"No .safetensors file found in {model_dir}")

        model.load_weights(weight_file, strict=True)
        model.eval()

        return model

    def _load_tokenizer(self) -> AutoTokenizer:
        """Load the tokenizer from the directory recorded by ``_load_jina_model``."""
        return AutoTokenizer.from_pretrained(self._model_dir)

    def _score_batch(self, query: str, documents: Sequence[str]) -> mx.array:
        """Return the raw (un-normalized) model scores for one batch."""
        input_ids, attention_mask, token_type_ids, position_ids = self._prepare_inputs(
            query, documents, self.tokenizer, max_length=1024
        )

        scores = self.model.nexa_forward(input_ids, attention_mask, token_type_ids, position_ids)
        return mx.squeeze(scores, axis=-1)

    def _rerank_batch(self, query: str, documents: List[str], config: RerankConfig) -> mx.array:
        """Rerank a batch of documents and return their scores.

        Normalization here only covers this one batch.  ``rerank()`` does not
        use this method any more; it is kept for callers that score a single
        batch directly.
        """
        scores = self._score_batch(query, documents)

        # Apply normalization if requested
        if config.normalize:
            scores = self._normalize_scores(scores, config.normalize_method)

        return scores

    def _create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
        """Create position ids from input ids, accounting for padding tokens"""
        # 1 for real tokens, 0 for padding
        mask = (input_ids != padding_idx).astype(mx.int32)
        # Running count of real tokens; the mask zeroes padding slots
        incremental_indices = (mx.cumsum(mask, axis=1) + past_key_values_length) * mask
        return incremental_indices.astype(mx.int32) + padding_idx

    def _prepare_inputs(self, query, documents, tokenizer, max_length=1024):
        """Prepare inputs for the model - match torch exactly"""
        sentence_pairs = [[query, doc] for doc in documents]
        inputs = tokenizer(
            sentence_pairs,
            padding="max_length",
            truncation=True,
            return_tensors="np",
            max_length=max_length,
        )

        input_ids = mx.array(inputs["input_ids"]).astype(mx.int32)
        seqlen = input_ids.shape[1]
        attention_mask = mx.array(inputs["attention_mask"]).astype(mx.float32)

        # Create token_type_ids as 1D tensor like torch, then broadcast for each batch item
        token_type_ids_1d = mx.zeros(seqlen, dtype=mx.int32)
        batch_size = input_ids.shape[0]
        token_type_ids = mx.broadcast_to(
            mx.expand_dims(token_type_ids_1d, axis=0), (batch_size, seqlen)
        )

        # Create position ids for each sequence in the batch
        position_ids = self._create_position_ids_from_input_ids(input_ids, padding_idx=1)

        return input_ids, attention_mask, token_type_ids, position_ids

    def _normalize_scores(self, scores: mx.array, method: str) -> mx.array:
        """Normalize scores using specified method.

        ``"softmax"`` normalizes along the last axis, ``"min-max"`` rescales
        to ``[0, 1]`` unless all scores are equal, and ``"none"`` or any
        unrecognized method returns the scores unchanged.
        """
        if method == "softmax":
            # For a 1D array the last axis is axis 0, so one call covers both
            # of the cases the original distinguished.
            return mx.softmax(scores, axis=-1)

        if method == "min-max":
            min_val = mx.min(scores)
            max_val = mx.max(scores)
            # Skip the rescale when the range is zero to avoid dividing by 0
            if max_val > min_val:
                return (scores - min_val) / (max_val - min_val)
            return scores

        return scores
275
-
276
-
277
- # Factory function for creating reranker instances
278
def create_reranker(
    model_path: PathType,
    tokenizer_path: Optional[PathType] = None,
    device: Optional[str] = None,
) -> Reranker:
    """Build a ``Reranker``; the tokenizer path defaults to ``model_path``."""
    resolved_tokenizer_path = model_path if tokenizer_path is None else tokenizer_path
    return Reranker(model_path, resolved_tokenizer_path, device)
@@ -1,127 +0,0 @@
1
- # Copyright © Nexa AI
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import time
16
- import mlx.core as mx
17
- from .interface import create_reranker, RerankConfig
18
-
19
-
20
def test_reranking():
    """Run an end-to-end reranking smoke test and print the results.

    Loads the ``nexaml/jina-v2-rerank-mlx`` model, scores four sample
    documents against a single query with normalization disabled, and
    prints each raw score, its sigmoid probability, the elapsed time and
    the throughput.

    Returns:
        None. If the model fails to load, a message is printed and the
        function returns before any scoring happens.
    """
    # Create reranker instance
    model_path = "nexaml/jina-v2-rerank-mlx"
    reranker = create_reranker(model_path=model_path)

    # Load the model
    print("Loading reranking model...")
    success = reranker.load_model(model_path, extra_data="nexaml/jina-v2-rerank-mlx")

    if not success:
        print("Failed to load model!")
        return

    print("✅ Model loaded successfully!")

    # From here on the model is loaded, so make sure close() always runs,
    # even if scoring or printing raises.
    try:
        # Test query and documents (same as generate.py)
        query = "What are the health benefits of green tea?"
        documents = [
            "Green tea is rich in antioxidants and may improve brain function.",
            "Coffee contains caffeine and can boost energy levels.",
            "Das Trinken von grünem Tee kann das Risiko für Herzkrankheiten senken.",
            "Black tea is another popular beverage with its own health benefits.",
        ]

        # Configure reranking with no normalization to get raw scores
        config = RerankConfig(
            batch_size=len(documents),
            normalize=False,
            normalize_method="none",
        )

        # perf_counter() is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        start_time = time.perf_counter()
        scores = reranker.rerank(query, documents, config)
        # MLX builds computations lazily; force evaluation before stopping the
        # timer so the measurement covers the actual compute.
        # NOTE(review): assumes rerank() returns an mx.array - confirm.
        mx.eval(scores)
        end_time = time.perf_counter()

        # Calculate sigmoid probabilities manually
        scores_sigmoid = mx.sigmoid(scores).tolist()

        inference_time = (end_time - start_time) * 1000  # Convert to ms

        print("=" * 70)
        print("Reranking Results:")
        print("=" * 70)
        print(f"Query: {query}")
        print()

        for i, (doc, score, prob) in enumerate(zip(documents, scores.tolist(), scores_sigmoid)):
            print(f"Document {i+1}:")
            print(f"  Text: {doc}")
            print(f"  Score: {score:.4f}")
            print(f"  Probability: {prob:.4f}")
            print()

        print(f"Inference time: {inference_time:.1f}ms")
        print(f"Throughput: {len(documents)/inference_time*1000:.1f} docs/s")
    finally:
        # Cleanup
        reranker.close()
80
-
81
-
82
def main(model_id):
    """Run a short reranking demo for ``model_id`` and print score statistics.

    Args:
        model_id: Identifier passed both to ``create_reranker`` and to
            ``load_model`` (as the model path and as ``extra_data``).

    Returns:
        None. If the model fails to load, a message is printed and the
        function returns before any scoring happens.
    """
    # Create reranker instance
    reranker = create_reranker(model_path=model_id)

    # Load the model
    success = reranker.load_model(model_id, extra_data=model_id)

    if not success:
        print("Failed to load model!")
        return

    # The model is loaded from here on; guarantee close() runs even when
    # scoring or printing raises.
    try:
        # Simple test like embedding generate.py
        query = "What are the health benefits of green tea?"
        documents = [
            "Green tea is rich in antioxidants and may improve brain function.",
            "Coffee contains caffeine and can boost energy levels.",
        ]

        # Get raw scores
        config = RerankConfig(normalize=False)
        scores = reranker.rerank(query, documents, config)

        # Calculate statistics on raw MLX array
        scores_sigmoid = mx.sigmoid(scores)

        print(f"Scores shape: {scores.shape}")
        print(f"Score sample values: {scores.tolist()}")
        print(f"Scores min: {scores.min():.4f}, Max: {scores.max():.4f}, Mean: {scores.mean():.4f}, Std: {scores.std():.4f}")
        print(f"Sigmoid probabilities: {scores_sigmoid.tolist()}")
    finally:
        # Cleanup
        reranker.close()
115
-
116
-
117
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="nexaml/jina-v2-rerank-mlx")
    # Register the flag explicitly: without it the parsed namespace never has
    # a `simple` attribute and the main() branch below can never be selected.
    parser.add_argument(
        "--simple",
        action="store_true",
        help="run the short main() demo instead of the full test_reranking() run",
    )
    args = parser.parse_args()

    # Use test_reranking for comprehensive test, main for simple format like generate.py.
    # test_reranking() is called without arguments, so --model_path only
    # affects the --simple path.
    if args.simple:
        main(args.model_path)
    else:
        test_reranking()
File without changes