fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/server.py ADDED
@@ -0,0 +1,1107 @@
1
+ import argparse
2
+ import asyncio
3
+ import gc
4
+ import json
5
+ import os
6
+ import traceback
7
+ import uuid
8
+ from datetime import datetime
9
+ from typing import Any, List, Literal, Optional, Tuple, Union
10
+
11
+ import mlx.core as mx
12
+ import uvicorn
13
+ from fastapi import FastAPI, HTTPException, Request
14
+ from fastapi.responses import StreamingResponse
15
+ from huggingface_hub import scan_cache_dir
16
+ from pydantic import BaseModel, ConfigDict, Field
17
+ from typing_extensions import Required, TypeAlias, TypedDict
18
+
19
+ from .generate import (
20
+ DEFAULT_MAX_TOKENS,
21
+ DEFAULT_MODEL_PATH,
22
+ DEFAULT_SEED,
23
+ DEFAULT_TEMPERATURE,
24
+ DEFAULT_TOP_P,
25
+ generate,
26
+ stream_generate,
27
+ )
28
+ from .prompt_utils import apply_chat_template
29
+ from .utils import load
30
+ from .version import __version__
31
+
32
+ app = FastAPI(
33
+ title="MLX-VLM Inference API",
34
+ description="API for using Vision Language Models (VLMs) and Omni Models (Vision, Audio and Video support) with MLX.",
35
+ version=__version__,
36
+ )
37
+
38
+ MAX_IMAGES = 10 # Maximum number of images to process at once
39
+
40
+ # Loading/unloading utilities
41
+
42
+ model_cache = {}
43
+
44
+
45
+ class FlexibleBaseModel(BaseModel):
46
+ """Base model that ignores/accepts any unknown OpenAI SDK fields."""
47
+
48
+ model_config = ConfigDict(extra="allow")
49
+
50
+
51
+ def load_model_resources(model_path: str, adapter_path: Optional[str]):
52
+ """
53
+ Loads model, processor, and config based on paths.
54
+ Handles potential loading errors.
55
+ """
56
+ try:
57
+ print(f"Loading model from: {model_path}")
58
+ if adapter_path:
59
+ print(f"Loading adapter from: {adapter_path}")
60
+ # Use the load function from utils.py which handles path resolution and loading
61
+ trust_remote_code = (
62
+ os.environ.get("MLX_TRUST_REMOTE_CODE", "false").lower() == "true"
63
+ )
64
+ model, processor = load(
65
+ model_path, adapter_path, trust_remote_code=trust_remote_code
66
+ )
67
+ config = model.config
68
+ print("Model and processor loaded successfully.")
69
+ return model, processor, config
70
+ except Exception as e:
71
+ print(f"Error loading model {model_path}: {e}")
72
+ traceback.print_exc() # Print detailed traceback for debugging
73
+ raise HTTPException(status_code=500, detail=f"Failed to load model: {e}")
74
+
75
+
76
+ def get_cached_model(model_path: str, adapter_path: Optional[str] = None):
77
+ """
78
+ Return the model, processor, and config from the cache, loading them first if necessary.
79
+ """
80
+ global model_cache
81
+
82
+ cache_key = (model_path, adapter_path)
83
+
84
+ # Return from cache if already loaded and matches the requested paths
85
+ if model_cache.get("cache_key") == cache_key:
86
+ print(f"Using cached model: {model_path}, Adapter: {adapter_path}")
87
+ return model_cache["model"], model_cache["processor"], model_cache["config"]
88
+
89
+ # If cache exists but doesn't match, clear it
90
+ if model_cache:
91
+ print("New model request, clearing existing cache...")
92
+ unload_model_sync() # Use a synchronous version for internal call
93
+
94
+ # Load the model resources
95
+ model, processor, config = load_model_resources(model_path, adapter_path)
96
+
97
+ model_cache = {
98
+ "cache_key": cache_key,
99
+ "model_path": model_path,
100
+ "adapter_path": adapter_path,
101
+ "model": model,
102
+ "processor": processor,
103
+ "config": config,
104
+ }
105
+
106
+ return model, processor, config
107
+
108
+
109
+ # Synchronous unload function for internal use
110
+ def unload_model_sync():
111
+ global model_cache
112
+ if not model_cache:
113
+ return False
114
+
115
+ print(
116
+ f"Unloading model: {model_cache.get('model_path')}, Adapter: {model_cache.get('adapter_path')}"
117
+ )
118
+ # Clear references
119
+ model_cache = {}
120
+ # Force garbage collection
121
+ gc.collect()
122
+ mx.clear_cache()
123
+ print("Model unloaded and cache cleared.")
124
+ return True
125
+
126
+
127
+ # OpenAI API Models
128
+
129
+ # Models for /responses endpoint
130
+
131
+
132
+ class ResponseInputTextParam(TypedDict, total=False):
133
+ text: Required[str]
134
+ type: Required[
135
+ Literal["input_text", "text"]
136
+ ] # The type of the input item. Always `input_text`.
137
+
138
+
139
+ class ResponseInputImageParam(TypedDict, total=False):
140
+ detail: Literal["high", "low", "auto"] = Field(
141
+ "auto", description="The detail level of the image to be sent to the model."
142
+ )
143
+ """The detail level of the image to be sent to the model.
144
+
145
+ One of `high`, `low`, or `auto`. Defaults to `auto`.
146
+ """
147
+ type: Required[
148
+ Literal["input_image"]
149
+ ] # The type of the input item. Always `input_image`.
150
+ image_url: Required[str]
151
+ file_id: Optional[str]
152
+ """The ID of the file to be sent to the model.
153
+ NOTE: would it help if the file_id were also passed through to the VLM?
154
+ """
155
+
156
+
157
+ class InputAudio(TypedDict, total=False):
158
+ data: Required[str]
159
+ format: Required[str]
160
+
161
+
162
+ class ResponseInputAudioParam(TypedDict, total=False):
163
+ type: Required[
164
+ Literal["input_audio"]
165
+ ] # The type of the input item. Always `input_audio`.
166
+ input_audio: Required[InputAudio]
167
+
168
+
169
+ class ImageUrl(TypedDict, total=False):
170
+ url: Required[str]
171
+
172
+
173
+ class ResponseImageUrlParam(TypedDict, total=False):
174
+ type: Required[
175
+ Literal["image_url"]
176
+ ] # The type of the input item. Always `image_url`.
177
+ image_url: Required[ImageUrl]
178
+
179
+
180
+ ResponseInputContentParam: TypeAlias = Union[
181
+ ResponseInputTextParam,
182
+ ResponseInputImageParam,
183
+ ResponseImageUrlParam,
184
+ ResponseInputAudioParam,
185
+ ]
186
+
187
+ ResponseInputMessageContentListParam: TypeAlias = List[ResponseInputContentParam]
188
+
189
+
190
+ class ResponseOutputText(TypedDict, total=False):
191
+ text: Required[str]
192
+ type: Required[
193
+ Literal["output_text"]
194
+ ] # The type of the output item. Always `output_text`
195
+
196
+
197
+ ResponseOutputMessageContentList: TypeAlias = List[ResponseOutputText]
198
+
199
+
200
+ class ChatMessage(FlexibleBaseModel):
201
+ role: Literal["user", "assistant", "system", "developer"] = Field(
202
+ ...,
203
+ description="Role of the message sender (e.g., 'system', 'user', 'assistant').",
204
+ )
205
+ content: Union[
206
+ str, ResponseInputMessageContentListParam, ResponseOutputMessageContentList
207
+ ] = Field(..., description="Content of the message.")
208
+
209
+
210
+ class OpenAIRequest(FlexibleBaseModel):
211
+ """
212
+ OpenAI-compatible request structure.
213
+ Using this structure: https://github.com/openai/openai-python/blob/main/src/openai/resources/responses/responses.py
214
+ """
215
+
216
+ input: Union[str, List[ChatMessage]] = Field(
217
+ ..., description="Input text or list of chat messages."
218
+ )
219
+ model: str = Field(..., description="The model to use for generation.")
220
+ max_output_tokens: int = Field(
221
+ DEFAULT_MAX_TOKENS, description="Maximum number of tokens to generate."
222
+ )
223
+ temperature: float = Field(
224
+ DEFAULT_TEMPERATURE, description="Temperature for sampling."
225
+ )
226
+ top_p: float = Field(DEFAULT_TOP_P, description="Top-p sampling.")
227
+ stream: bool = Field(
228
+ False, description="Whether to stream the response chunk by chunk."
229
+ )
230
+
231
+
232
+ class OpenAIUsage(BaseModel):
233
+ """Token usage details including input tokens, output tokens, breakdown, and total tokens used."""
234
+
235
+ input_tokens: int
236
+ output_tokens: int
237
+ total_tokens: int
238
+
239
+
240
+ class OpenAIErrorObject(BaseModel):
241
+ """Error object returned when the model fails to generate a Response."""
242
+
243
+ code: Optional[str] = None
244
+ message: Optional[str] = None
245
+ param: Optional[str] = None
246
+ type: Optional[str] = None
247
+
248
+
249
+ class OpenAIResponse(BaseModel):
250
+ id: str = Field(..., description="Unique identifier for this Response")
251
+ object: Literal["response"] = Field(
252
+ ..., description="The object type of this resource - always set to response"
253
+ )
254
+ created_at: int = Field(
255
+ ..., description="Unix timestamp (in seconds) of when this Response was created"
256
+ )
257
+ status: Literal["completed", "failed", "in_progress", "incomplete"] = Field(
258
+ ..., description="The status of the response generation"
259
+ )
260
+ error: Optional[OpenAIErrorObject] = Field(
261
+ None,
262
+ description="An error object returned when the model fails to generate a Response",
263
+ )
264
+ instructions: Optional[str] = Field(
265
+ None,
266
+ description="Inserts a system (or developer) message as the first item in the model's context",
267
+ )
268
+ max_output_tokens: Optional[int] = Field(
269
+ None,
270
+ description="An upper bound for the number of tokens that can be generated for a response",
271
+ )
272
+ model: str = Field(..., description="Model ID used to generate the response")
273
+ output: List[Union[ChatMessage, Any]] = Field(
274
+ ..., description="An array of content items generated by the model"
275
+ )
276
+ output_text: Optional[str] = Field(
277
+ None,
278
+ description="SDK-only convenience property containing aggregated text output",
279
+ )
280
+ temperature: Optional[float] = Field(
281
+ None, ge=0, le=2, description="Sampling temperature between 0 and 2"
282
+ )
283
+ top_p: Optional[float] = Field(
284
+ None, ge=0, le=1, description="Nucleus sampling probability mass"
285
+ )
286
+ truncation: Union[Literal["auto", "disabled"], str] = Field(
287
+ "disabled", description="The truncation strategy to use"
288
+ )
289
+ usage: OpenAIUsage = Field(
290
+ ..., description="Token usage details"
291
+ ) # we need the model to return stats
292
+ user: Optional[str] = Field(
293
+ None, description="A unique identifier representing your end-user"
294
+ )
295
+
296
+
297
+ class BaseStreamEvent(BaseModel):
298
+ type: str
299
+
300
+
301
+ class ContentPartOutputText(BaseModel):
302
+ type: Literal["output_text"]
303
+ text: str
304
+ annotations: List[str] = []
305
+
306
+
307
+ class MessageItem(BaseModel):
308
+ id: str
309
+ type: Literal["message"]
310
+ status: Literal["in_progress", "completed"]
311
+ role: str
312
+ content: List[ContentPartOutputText] = []
313
+
314
+
315
+ class ResponseCreatedEvent(BaseStreamEvent):
316
+ type: Literal["response.created"]
317
+ response: OpenAIResponse
318
+
319
+
320
+ class ResponseInProgressEvent(BaseStreamEvent):
321
+ type: Literal["response.in_progress"]
322
+ response: OpenAIResponse
323
+
324
+
325
+ class ResponseOutputItemAddedEvent(BaseStreamEvent):
326
+ type: Literal["response.output_item.added"]
327
+ output_index: int
328
+ item: MessageItem
329
+
330
+
331
+ class ResponseContentPartAddedEvent(BaseStreamEvent):
332
+ type: Literal["response.content_part.added"]
333
+ item_id: str
334
+ output_index: int
335
+ content_index: int
336
+ part: ContentPartOutputText
337
+
338
+
339
+ class ResponseOutputTextDeltaEvent(BaseStreamEvent):
340
+ type: Literal["response.output_text.delta"]
341
+ item_id: str
342
+ output_index: int
343
+ content_index: int
344
+ delta: str
345
+
346
+
347
+ class ResponseOutputTextDoneEvent(BaseStreamEvent):
348
+ type: Literal["response.output_text.done"]
349
+ item_id: str
350
+ output_index: int
351
+ content_index: int
352
+ text: str
353
+
354
+
355
+ class ResponseContentPartDoneEvent(BaseStreamEvent):
356
+ type: Literal["response.content_part.done"]
357
+ item_id: str
358
+ output_index: int
359
+ content_index: int
360
+ part: ContentPartOutputText
361
+
362
+
363
+ class ResponseOutputItemDoneEvent(BaseStreamEvent):
364
+ type: Literal["response.output_item.done"]
365
+ output_index: int
366
+ item: MessageItem
367
+
368
+
369
+ class ResponseCompletedEvent(BaseStreamEvent):
370
+ type: Literal["response.completed"]
371
+ response: OpenAIResponse
372
+
373
+
374
+ StreamEvent = Union[
375
+ ResponseCreatedEvent,
376
+ ResponseInProgressEvent,
377
+ ResponseOutputItemAddedEvent,
378
+ ResponseContentPartAddedEvent,
379
+ ResponseOutputTextDeltaEvent,
380
+ ResponseOutputTextDoneEvent,
381
+ ResponseContentPartDoneEvent,
382
+ ResponseOutputItemDoneEvent,
383
+ ResponseCompletedEvent,
384
+ ]
385
+
386
+ # Models for /chat/completions endpoint
387
+
388
+
389
+ class VLMRequest(FlexibleBaseModel):
390
+ model: str = Field(
391
+ DEFAULT_MODEL_PATH,
392
+ description="The path to the local model directory or Hugging Face repo.",
393
+ )
394
+ adapter_path: Optional[str] = Field(
395
+ None, description="The path to the adapter weights."
396
+ )
397
+ max_tokens: int = Field(
398
+ DEFAULT_MAX_TOKENS, description="Maximum number of tokens to generate."
399
+ )
400
+ temperature: float = Field(
401
+ DEFAULT_TEMPERATURE, description="Temperature for sampling."
402
+ )
403
+ top_p: float = Field(DEFAULT_TOP_P, description="Top-p sampling.")
404
+ seed: int = Field(DEFAULT_SEED, description="Seed for random generation.")
405
+ resize_shape: Optional[Tuple[int, int]] = Field(
406
+ None,
407
+ description="Resize shape for the image (height, width). Provide two integers.",
408
+ )
409
+
410
+
411
+ class GenerationRequest(VLMRequest):
412
+ """
413
+ Inherits from VLMRequest and adds the streaming flag for the generation request.
414
+ """
415
+
416
+ stream: bool = Field(
417
+ False, description="Whether to stream the response chunk by chunk."
418
+ )
419
+
420
+
421
+ class UsageStats(OpenAIUsage):
422
+ """
423
+ Inherits from OpenAIUsage and adds throughput and memory statistics.
424
+ """
425
+
426
+ prompt_tps: float = Field(..., description="Tokens per second for the prompt.")
427
+ generation_tps: float = Field(
428
+ ..., description="Tokens per second for the generation."
429
+ )
430
+ peak_memory: float = Field(
431
+ ..., description="Peak memory usage during the generation."
432
+ )
433
+
434
+
435
+ class ChatRequest(GenerationRequest):
436
+ messages: List[ChatMessage]
437
+
438
+
439
+ class ChatChoice(BaseModel):
440
+ finish_reason: str
441
+ message: ChatMessage
442
+
443
+
444
+ class ChatResponse(BaseModel):
445
+ model: str
446
+ choices: List[ChatChoice]
447
+ usage: Optional[UsageStats]
448
+
449
+
450
+ class ChatStreamChoice(BaseModel):
451
+ finish_reason: Optional[str] = None
452
+ delta: ChatMessage
453
+
454
+
455
+ class ChatStreamChunk(BaseModel):
456
+ model: str
457
+ choices: List[ChatStreamChoice]
458
+ usage: Optional[UsageStats]
459
+
460
+
461
+ # Models for /models endpoint
462
+
463
+
464
+ class ModelInfo(BaseModel):
465
+ id: str
466
+ object: str
467
+ created: int
468
+
469
+
470
+ class ModelsResponse(BaseModel):
471
+ object: Literal["list"]
472
+ data: List[ModelInfo]
473
+
474
+
475
+ # OpenAI-compatible endpoints
476
+
477
+
478
+ @app.post("/responses")
479
+ async def responses_endpoint(request: Request):
480
+ """
481
+ OpenAI-compatible endpoint for generating text based on a prompt and optional images.
482
+
483
+ Uses the client.responses.create method.
484
+
485
+ example:
486
+
487
+ from openai import OpenAI
488
+
489
+ API_URL = "http://0.0.0.0:8000"
490
+ API_KEY = 'any'
491
+
492
+ def run_openai(prompt, img_url,system, stream=False, max_output_tokens=512, model="mlx-community/Qwen2.5-VL-3B-Instruct-8bit"):
493
+ ''' Calls the OpenAI API
494
+ '''
495
+
496
+ client = OpenAI(base_url=f"{API_URL}", api_key=API_KEY)
497
+
498
+ try:
499
+ response = client.responses.create(
500
+ model=model,
501
+ input=[
502
+ {"role":"system",
503
+ "content": f"{system}"
504
+ },
505
+ {
506
+ "role": "user",
507
+ "content": [
508
+ {"type": "input_text", "text": prompt},
509
+ {"type": "input_image", "image_url": f"{img_url}"},
510
+ ],
511
+ }
512
+ ],
513
+ max_output_tokens=max_output_tokens,
514
+ stream=stream
515
+ )
516
+ if not stream:
517
+ print(response.output[0].content[0].text)
518
+ print(response.usage)
519
+ else:
520
+ for event in response:
521
+ # Process different event types if needed
522
+ if hasattr(event, 'delta') and event.delta:
523
+ print(event.delta, end="", flush=True)
524
+ elif event.type == 'response.completed':
525
+ print("\n--- Usage ---")
526
+ print(event.response.usage)
527
+
528
+ except Exception as e:
529
+ # building a response object to match the one returned when request is successful so that it can be processed in the same way
530
+ return {"model - error":str(e),"content":{}, "model":model}
531
+
532
+ """
533
+
534
+ body = await request.json()
535
+ openai_request = OpenAIRequest(**body)
536
+
537
+ try:
538
+ # Get model, processor, config - loading if necessary
539
+ model, processor, config = get_cached_model(openai_request.model)
540
+
541
+ kwargs = {}
542
+
543
+ chat_messages = []
544
+ images = []
545
+ instructions = None
546
+ if openai_request.input:
547
+ if isinstance(openai_request.input, str):
548
+ # If input is a string, treat it as a single text message
549
+ chat_messages.append({"role": "user", "content": openai_request.input})
550
+ elif isinstance(openai_request.input, list):
551
+ # If input is a list, treat it as a series of chat messages
552
+ for message in openai_request.input:
553
+ if isinstance(message, ChatMessage):
554
+ if isinstance(message.content, str):
555
+ chat_messages.append(
556
+ {"role": message.role, "content": message.content}
557
+ )
558
+ if message.role == "system":
559
+ instructions = message.content
560
+ elif isinstance(message.content, list):
561
+ # Handle list of content items
562
+ for item in message.content:
563
+ if isinstance(item, dict):
564
+ if item["type"] == "input_text":
565
+ chat_messages.append(
566
+ {
567
+ "role": message.role,
568
+ "content": item["text"],
569
+ }
570
+ )
571
+ if message.role == "system":
572
+ instructions = item["text"]
573
+ # examples for multiple images (https://platform.openai.com/docs/guides/images?api-mode=responses)
574
+ elif item["type"] == "input_image":
575
+ images.append(item["image_url"])
576
+ else:
577
+ print(
578
+ f"invalid input item type: {item['type']}"
579
+ )
580
+ raise HTTPException(
581
+ status_code=400,
582
+ detail="Invalid input item type.",
583
+ )
584
+ else:
585
+ print(
586
+ f"Invalid message content item format: {item}"
587
+ )
588
+ raise HTTPException(
589
+ status_code=400,
590
+ detail="Missing type in input item.",
591
+ )
592
+ else:
593
+ print("Invalid message content format.")
594
+ raise HTTPException(
595
+ status_code=400, detail="Invalid input format."
596
+ )
597
+ else:
598
+ print("not a ChatMessage")
599
+ raise HTTPException(
600
+ status_code=400, detail="Invalid input format."
601
+ )
602
+ else:
603
+ print("neither string not list")
604
+ raise HTTPException(status_code=400, detail="Invalid input format.")
605
+
606
+ else:
607
+ print("no input")
608
+ raise HTTPException(status_code=400, detail="Missing input.")
609
+
610
+ formatted_prompt = apply_chat_template(
611
+ processor, config, chat_messages, num_images=len(images)
612
+ )
613
+
614
+ generated_at = datetime.now().timestamp()
615
+ response_id = f"resp_{uuid.uuid4().hex}"
616
+ message_id = f"msg_{uuid.uuid4().hex}"
617
+
618
+ if openai_request.stream:
619
+ # Streaming response
620
+ async def stream_generator():
621
+ token_iterator = None
622
+ try:
623
+ # Create base response object (to match the openai pipeline)
624
+ base_response = OpenAIResponse(
625
+ id=response_id,
626
+ object="response",
627
+ created_at=int(generated_at),
628
+ status="in_progress",
629
+ instructions=instructions,
630
+ max_output_tokens=openai_request.max_output_tokens,
631
+ model=openai_request.model,
632
+ output=[],
633
+ output_text="",
634
+ temperature=openai_request.temperature,
635
+ top_p=openai_request.top_p,
636
+ usage={
637
+ "input_tokens": 0, # get prompt tokens
638
+ "output_tokens": 0,
639
+ "total_tokens": 0,
640
+ },
641
+ )
642
+
643
+ # Send response.created event (to match the openai pipeline)
644
+ yield f"event: response.created\ndata: {ResponseCreatedEvent(type='response.created', response=base_response).model_dump_json()}\n\n"
645
+
646
+ # Send response.in_progress event (to match the openai pipeline)
647
+ yield f"event: response.in_progress\ndata: {ResponseInProgressEvent(type='response.in_progress', response=base_response).model_dump_json()}\n\n"
648
+
649
+ # Send response.output_item.added event (to match the openai pipeline)
650
+ message_item = MessageItem(
651
+ id=message_id,
652
+ type="message",
653
+ status="in_progress",
654
+ role="assistant",
655
+ content=[],
656
+ )
657
+ yield f"event: response.output_item.added\ndata: {ResponseOutputItemAddedEvent(type='response.output_item.added', output_index=0, item=message_item).model_dump_json()}\n\n"
658
+
659
+ # Send response.content_part.added event
660
+ content_part = ContentPartOutputText(
661
+ type="output_text", text="", annotations=[]
662
+ )
663
+ yield f"event: response.content_part.added\ndata: {ResponseContentPartAddedEvent(type='response.content_part.added', item_id=message_id, output_index=0, content_index=0, part=content_part).model_dump_json()}\n\n"
664
+
665
+ # Stream text deltas
666
+ token_iterator = stream_generate(
667
+ model=model,
668
+ processor=processor,
669
+ prompt=formatted_prompt,
670
+ image=images,
671
+ temperature=openai_request.temperature,
672
+ max_tokens=openai_request.max_output_tokens,
673
+ top_p=openai_request.top_p,
674
+ **kwargs,
675
+ )
676
+
677
+ full_text = ""
678
+ for chunk in token_iterator:
679
+ if chunk is None or not hasattr(chunk, "text"):
680
+ continue
681
+
682
+ delta = chunk.text
683
+ full_text += delta
684
+
685
+ usage_stats = {
686
+ "input_tokens": chunk.prompt_tokens,
687
+ "output_tokens": chunk.generation_tokens,
688
+ }
689
+
690
+ # Send response.output_text.delta event
691
+ yield f"event: response.output_text.delta\ndata: {ResponseOutputTextDeltaEvent(type='response.output_text.delta', item_id=message_id, output_index=0, content_index=0, delta=delta).model_dump_json()}\n\n"
692
+ await asyncio.sleep(0.01)
693
+
694
+ # Send response.output_text.done event (to match the openai pipeline)
695
+ yield f"event: response.output_text.done\ndata: {ResponseOutputTextDoneEvent(type='response.output_text.done', item_id=message_id, output_index=0, content_index=0, text=full_text).model_dump_json()}\n\n"
696
+
697
+ # Send response.content_part.done event (to match the openai pipeline)
698
+ final_content_part = ContentPartOutputText(
699
+ type="output_text", text=full_text, annotations=[]
700
+ )
701
+ yield f"event: response.content_part.done\ndata: {ResponseContentPartDoneEvent(type='response.content_part.done', item_id=message_id, output_index=0, content_index=0, part=final_content_part).model_dump_json()}\n\n"
702
+
703
+ # Send response.output_item.done event (to match the openai pipeline)
704
+ final_message_item = MessageItem(
705
+ id=message_id,
706
+ type="message",
707
+ status="completed",
708
+ role="assistant",
709
+ content=[final_content_part],
710
+ )
711
+ yield f"event: response.output_item.done\ndata: {ResponseOutputItemDoneEvent(type='response.output_item.done', output_index=0, item=final_message_item).model_dump_json()}\n\n"
712
+
713
+ # Send response.completed event (to match the openai pipeline)
714
+ completed_response = base_response.model_copy(
715
+ update={
716
+ "status": "completed",
717
+ "output": [final_message_item],
718
+ "usage": {
719
+ "input_tokens": usage_stats["input_tokens"],
720
+ "output_tokens": usage_stats["output_tokens"],
721
+ "total_tokens": usage_stats["input_tokens"]
722
+ + usage_stats["output_tokens"],
723
+ },
724
+ }
725
+ )
726
+ yield f"event: response.completed\ndata: {ResponseCompletedEvent(type='response.completed', response=completed_response).model_dump_json()}\n\n"
727
+
728
+ except Exception as e:
729
+ print(f"Error during stream generation: {e}")
730
+ traceback.print_exc()
731
+ error_data = json.dumps({"error": str(e)})
732
+ yield f"data: {error_data}\n\n"
733
+
734
+ finally:
735
+ mx.clear_cache()
736
+ gc.collect()
737
+ print("Stream finished, cleared cache.")
738
+
739
+ return StreamingResponse(stream_generator(), media_type="text/event-stream")
740
+
741
+ else:
742
+ # Non-streaming response
743
+ try:
744
+ # Use generate from generate.py
745
+ result = generate(
746
+ model=model,
747
+ processor=processor,
748
+ prompt=formatted_prompt,
749
+ image=images,
750
+ temperature=openai_request.temperature,
751
+ max_tokens=openai_request.max_output_tokens,
752
+ top_p=openai_request.top_p,
753
+ verbose=False, # stats are passed in the response
754
+ **kwargs,
755
+ )
756
+ # Clean up resources
757
+ mx.clear_cache()
758
+ gc.collect()
759
+ print("Generation finished, cleared cache.")
760
+
761
+ response = OpenAIResponse(
762
+ id=response_id,
763
+ object="response",
764
+ created_at=int(generated_at),
765
+ status="completed",
766
+ instructions=instructions,
767
+ max_output_tokens=openai_request.max_output_tokens,
768
+ model=openai_request.model,
769
+ output=[
770
+ {
771
+ "role": "assistant",
772
+ "content": [
773
+ {
774
+ "type": "output_text",
775
+ "text": result.text,
776
+ }
777
+ ],
778
+ }
779
+ ],
780
+ output_text=result.text,
781
+ temperature=openai_request.temperature,
782
+ top_p=openai_request.top_p,
783
+ usage={
784
+ "input_tokens": result.prompt_tokens,
785
+ "output_tokens": result.generation_tokens,
786
+ "total_tokens": result.total_tokens,
787
+ },
788
+ )
789
+ return response
790
+
791
+ except Exception as e:
792
+ print(f"Error during generation: {e}")
793
+ traceback.print_exc()
794
+ mx.clear_cache()
795
+ gc.collect()
796
+ raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
797
+
798
+ except HTTPException as http_exc:
799
+ # Re-raise HTTP exceptions (like model loading failure)
800
+ raise http_exc
801
+ except Exception as e:
802
+ # Catch unexpected errors
803
+ print(f"Unexpected error in /responses endpoint: {e}")
804
+ traceback.print_exc()
805
+ mx.clear_cache()
806
+ gc.collect()
807
+ raise HTTPException(
808
+ status_code=500, detail=f"An unexpected error occurred: {e}"
809
+ )
810
+
811
+
812
+ @app.post(
813
+ "/chat/completions", response_model=None
814
+ ) # Response model handled dynamically based on stream flag
815
+ async def chat_completions_endpoint(request: ChatRequest):
816
+ """
817
+ Generate text based on a prompt and optional images.
818
+ Prompt must be a list of chat messages, including system, user, and assistant messages.
819
+ A system message is applied only if it is already included in the messages.
820
+ Can operate in streaming or non-streaming mode.
821
+ """
822
+
823
+ try:
824
+ # Get model, processor, config - loading if necessary
825
+ model, processor, config = get_cached_model(request.model, request.adapter_path)
826
+
827
+ kwargs = {}
828
+
829
+ if request.resize_shape is not None:
830
+ if len(request.resize_shape) not in [1, 2]:
831
+ raise HTTPException(
832
+ status_code=400,
833
+ detail="resize_shape must contain exactly two integers (height, width)",
834
+ )
835
+ kwargs["resize_shape"] = (
836
+ (request.resize_shape[0],) * 2
837
+ if len(request.resize_shape) == 1
838
+ else tuple(request.resize_shape)
839
+ )
840
+
841
+ chat_messages = request.messages
842
+
843
+ images = []
844
+ audio = []
845
+ processed_messages = []
846
+ for message in request.messages:
847
+ if isinstance(message.content, str):
848
+ processed_messages.append(
849
+ {"role": message.role, "content": message.content}
850
+ )
851
+ elif isinstance(message.content, list):
852
+ text_content = ""
853
+ for item in message.content:
854
+ if isinstance(item, dict):
855
+ # Only extract images/audio from user messages
856
+ if message.role == "user":
857
+ if item["type"] == "input_image":
858
+ images.append(item["image_url"])
859
+ elif item["type"] == "image_url":
860
+ images.append(item["image_url"]["url"])
861
+ elif item["type"] == "input_audio":
862
+ audio.append(item["input_audio"]["data"])
863
+ if item["type"] in ("text", "input_text"):
864
+ text_content = item.get("text", "")
865
+ processed_messages.append(
866
+ {"role": message.role, "content": text_content}
867
+ )
868
+
869
+ formatted_prompt = apply_chat_template(
870
+ processor,
871
+ config,
872
+ processed_messages,
873
+ num_images=len(images),
874
+ num_audios=len(audio),
875
+ )
876
+
877
+ if request.stream:
878
+ # Streaming response
879
+ async def stream_generator():
880
+ token_iterator = None
881
+ try:
882
+ # Use stream_generate from utils
883
+ token_iterator = stream_generate(
884
+ model=model,
885
+ processor=processor,
886
+ prompt=formatted_prompt,
887
+ image=images,
888
+ audio=audio,
889
+ temperature=request.temperature,
890
+ max_tokens=request.max_tokens,
891
+ top_p=request.top_p,
892
+ **kwargs,
893
+ )
894
+
895
+ for chunk in token_iterator:
896
+ if chunk is None or not hasattr(chunk, "text"):
897
+ print("Warning: Received unexpected chunk format:", chunk)
898
+ continue
899
+
900
+ # Yield chunks in Server-Sent Events (SSE) format
901
+ usage_stats = {
902
+ "input_tokens": chunk.prompt_tokens,
903
+ "output_tokens": chunk.generation_tokens,
904
+ "total_tokens": chunk.prompt_tokens
905
+ + chunk.generation_tokens,
906
+ "prompt_tps": chunk.prompt_tps,
907
+ "generation_tps": chunk.generation_tps,
908
+ "peak_memory": chunk.peak_memory,
909
+ }
910
+
911
+ choices = [
912
+ ChatStreamChoice(
913
+ delta=ChatMessage(role="assistant", content=chunk.text)
914
+ )
915
+ ]
916
+ chunk_data = ChatStreamChunk(
917
+ model=request.model, usage=usage_stats, choices=choices
918
+ )
919
+
920
+ yield f"data: {chunk_data.model_dump_json()}\n\n"
921
+ await asyncio.sleep(
922
+ 0.01
923
+ ) # Small sleep to prevent blocking event loop entirely
924
+
925
+ # Signal stream end
926
+ choices = [
927
+ ChatStreamChoice(
928
+ finish_reason="stop",
929
+ delta=ChatMessage(role="assistant", content=""),
930
+ )
931
+ ]
932
+ chunk_data = ChatStreamChunk(
933
+ model=request.model, usage=usage_stats, choices=choices
934
+ )
935
+ yield f"data: {chunk_data.model_dump_json()}\n\n"
936
+
937
+ except Exception as e:
938
+ print(f"Error during stream generation: {e}")
939
+ traceback.print_exc()
940
+ error_data = json.dumps({"error": str(e)})
941
+ yield f"data: {error_data}\n\n"
942
+
943
+ finally:
944
+ mx.clear_cache()
945
+ gc.collect()
946
+ print("Stream finished, cleared cache.")
947
+
948
+ return StreamingResponse(stream_generator(), media_type="text/event-stream")
949
+
950
+ else:
951
+ # Non-streaming response
952
+ try:
953
+ # Use generate from generate.py
954
+ gen_result = generate(
955
+ model=model,
956
+ processor=processor,
957
+ prompt=formatted_prompt,
958
+ image=images,
959
+ audio=audio,
960
+ temperature=request.temperature,
961
+ max_tokens=request.max_tokens,
962
+ top_p=request.top_p,
963
+ verbose=False, # Keep API output clean
964
+ **kwargs,
965
+ )
966
+ # Clean up resources
967
+ mx.clear_cache()
968
+ gc.collect()
969
+ print("Generation finished, cleared cache.")
970
+
971
+ usage_stats = UsageStats(
972
+ input_tokens=gen_result.prompt_tokens,
973
+ output_tokens=gen_result.generation_tokens,
974
+ total_tokens=gen_result.total_tokens,
975
+ prompt_tps=gen_result.prompt_tps,
976
+ generation_tps=gen_result.generation_tps,
977
+ peak_memory=gen_result.peak_memory,
978
+ )
979
+
980
+ choices = [
981
+ ChatChoice(
982
+ finish_reason="stop",
983
+ message=ChatMessage(role="assistant", content=gen_result.text),
984
+ )
985
+ ]
986
+ result = ChatResponse(
987
+ model=request.model, usage=usage_stats, choices=choices
988
+ )
989
+
990
+ return result
991
+
992
+ except Exception as e:
993
+ print(f"Error during generation: {e}")
994
+ traceback.print_exc()
995
+ mx.clear_cache()
996
+ gc.collect()
997
+ raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
998
+
999
+ except HTTPException as http_exc:
1000
+ # Re-raise HTTP exceptions (like model loading failure)
1001
+ raise http_exc
1002
+ except Exception as e:
1003
+ # Catch unexpected errors
1004
+ print(f"Unexpected error in /generate endpoint: {e}")
1005
+ traceback.print_exc()
1006
+ mx.clear_cache()
1007
+ gc.collect()
1008
+ raise HTTPException(
1009
+ status_code=500, detail=f"An unexpected error occurred: {e}"
1010
+ )
1011
+
1012
+
1013
+ @app.get("/models", response_model=ModelsResponse)
1014
+ def models_endpoint():
1015
+ """
1016
+ Return a list of locally downloaded MLX models.
1017
+ """
1018
+
1019
+ files = ["config.json", "model.safetensors.index.json", "tokenizer_config.json"]
1020
+
1021
+ def probably_mlx_lm(repo):
1022
+ if repo.repo_type != "model":
1023
+ return False
1024
+ if "main" not in repo.refs:
1025
+ return False
1026
+ file_names = {f.file_path.name for f in repo.refs["main"].files}
1027
+ return all(f in file_names for f in files)
1028
+
1029
+ # Scan the cache directory for downloaded mlx models
1030
+ hf_cache_info = scan_cache_dir()
1031
+ downloaded_models = [repo for repo in hf_cache_info.repos if probably_mlx_lm(repo)]
1032
+
1033
+ # Create a list of available models
1034
+ models = [
1035
+ {"id": repo.repo_id, "object": "model", "created": int(repo.last_modified)}
1036
+ for repo in downloaded_models
1037
+ ]
1038
+
1039
+ response = {"object": "list", "data": models}
1040
+
1041
+ return response
1042
+
1043
+
1044
+ # MLX_VLM API endpoints
1045
+
1046
+
1047
+ @app.get("/health")
1048
+ async def health_check():
1049
+ """
1050
+ Check if the server is healthy and what model is loaded.
1051
+ """
1052
+ return {
1053
+ "status": "healthy",
1054
+ "loaded_model": model_cache.get("model_path", None),
1055
+ "loaded_adapter": model_cache.get("adapter_path", None),
1056
+ }
1057
+
1058
+
1059
+ @app.post("/unload")
1060
+ async def unload_model_endpoint():
1061
+ """
1062
+ Unload the currently loaded model from memory.
1063
+ """
1064
+ unloaded_info = {
1065
+ "model_name": model_cache.get("model_path", None),
1066
+ "adapter_name": model_cache.get("adapter_path", None),
1067
+ }
1068
+
1069
+ if not unload_model_sync(): # Use the synchronous unload function
1070
+ return {"status": "no_model_loaded", "message": "No model is currently loaded"}
1071
+
1072
+ return {
1073
+ "status": "success",
1074
+ "message": f"Model unloaded successfully",
1075
+ "unloaded": unloaded_info,
1076
+ }
1077
+
1078
+
1079
+ def main():
1080
+ parser = argparse.ArgumentParser(description="MLX VLM HTTP server.")
1081
+ parser.add_argument(
1082
+ "--host",
1083
+ type=str,
1084
+ default="0.0.0.0",
1085
+ help="Host for the HTTP server (default:0.0.0.0)",
1086
+ )
1087
+ parser.add_argument(
1088
+ "--port",
1089
+ type=int,
1090
+ default=8080,
1091
+ help="Port for the HTTP server (default: 8080)",
1092
+ )
1093
+ parser.add_argument(
1094
+ "--trust-remote-code",
1095
+ action="store_true",
1096
+ help="Trust remote code when loading models from Hugging Face Hub.",
1097
+ )
1098
+ args = parser.parse_args()
1099
+ if args.trust_remote_code:
1100
+ os.environ["MLX_TRUST_REMOTE_CODE"] = "true"
1101
+ uvicorn.run(
1102
+ "mlx_vlm.server:app", host=args.host, port=args.port, workers=1, reload=True
1103
+ ) # reload=True for development to automatically restart on code changes.
1104
+
1105
+
1106
+ if __name__ == "__main__":
1107
+ main()
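
A minimal client sketch for the /chat/completions endpoint defined above, assuming the server is running locally on its default port (for example via `python -m mlx_vlm.server`, which serves on 0.0.0.0:8080) and that the requests library is available; the model id and image URL below are placeholders, not values shipped with the package:

import requests

BASE_URL = "http://localhost:8080"  # default host/port from main()

# Check server health and which model (if any) is currently cached.
print(requests.get(f"{BASE_URL}/health").json())

# Non-streaming chat completion with one image; the model is loaded on demand
# and cached for subsequent requests.
payload = {
    "model": "mlx-community/Qwen2.5-VL-3B-Instruct-8bit",  # placeholder model id
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "input_text", "text": "Describe this image."},
                {"type": "input_image", "image_url": "https://example.com/cat.png"},
            ],
        }
    ],
    "max_tokens": 256,
    "temperature": 0.0,
    "stream": False,
}
resp = requests.post(f"{BASE_URL}/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
data = resp.json()
print(data["choices"][0]["message"]["content"])
print(data["usage"])  # token counts plus prompt_tps, generation_tps, peak_memory

Setting "stream": True instead returns Server-Sent Events, one data: chunk per generated piece of text, ending with a chunk whose finish_reason is "stop".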