fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/chat_ui.py ADDED
@@ -0,0 +1,508 @@
+ import argparse
+ import gc
+ import json
+ import threading
+
+ import gradio as gr
+ import mlx.core as mx
+
+ from mlx_vlm import load
+
+ from .generate import stream_generate
+ from .prompt_utils import get_chat_template, get_message_json
+ from .utils import load_config, load_image_processor
+
+
+ def parse_arguments():
+     parser = argparse.ArgumentParser(
+         description="Generate text from an image using a model."
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         default="qnguyen3/nanoLLaVA",
+         help="The path to the local model directory or Hugging Face repo.",
+     )
+     return parser.parse_args()
+
+
+ # Global state for model
+ class ModelState:
+     def __init__(self):
+         self.model = None
+         self.processor = None
+         self.config = None
+         self.image_processor = None
+         self.current_model_name = None
+
+     def load(self, model_name):
+         """Load a model, clearing previous one from memory."""
+         # Clear previous model from memory
+         if self.model is not None:
+             del self.model
+             del self.processor
+             del self.config
+             del self.image_processor
+             mx.metal.clear_cache()
+             gc.collect()
+
+         # Load new model
+         self.config = load_config(model_name)
+         self.model, self.processor = load(
+             model_name, processor_kwargs={"trust_remote_code": True}
+         )
+         self.image_processor = load_image_processor(model_name)
+         self.current_model_name = model_name
+
+
+ state = ModelState()
+
+ # Parse args and load initial model
+ args = parse_arguments()
+ state.load(args.model)
+
+ # Use most of the viewport for conversation
+ chatbot_height = "clamp(380px, calc(100vh - 450px), 820px)"
+
+ # Global flag for stopping generation
+ stop_generation = threading.Event()
+
+
+ def get_cached_vlm_models():
+     """Scan HF cache for vision-capable models."""
+     try:
+         from huggingface_hub import scan_cache_dir
+
+         vlm_models = []
+         cache_info = scan_cache_dir()
+
+         for repo in cache_info.repos:
+             if repo.repo_type != "model":
+                 continue
+
+             # Check for refs
+             refs = getattr(repo, "refs", {})
+             if not refs or "main" not in refs:
+                 # Try revisions instead
+                 revisions = getattr(repo, "revisions", None)
+                 if revisions:
+                     for rev in revisions:
+                         snapshot_path = getattr(rev, "snapshot_path", None)
+                         if snapshot_path:
+                             config_path = snapshot_path / "config.json"
+                             if config_path.exists():
+                                 try:
+                                     with open(config_path, "r") as f:
+                                         config = json.load(f)
+                                     if "vision_config" in config:
+                                         vlm_models.append(repo.repo_id)
+                                         break
+                                 except Exception:
+                                     pass
+                 continue
+
+             # Check config.json for vision_config
+             main_ref = refs["main"]
+             snapshot_path = getattr(main_ref, "snapshot_path", None)
+             if snapshot_path:
+                 config_path = snapshot_path / "config.json"
+                 if config_path.exists():
+                     try:
+                         with open(config_path, "r") as f:
+                             config = json.load(f)
+                         if "vision_config" in config:
+                             vlm_models.append(repo.repo_id)
+                     except Exception:
+                         pass
+
+         # Ensure current model is in the list
+         if state.current_model_name and state.current_model_name not in vlm_models:
+             vlm_models.insert(0, state.current_model_name)
+
+         return sorted(set(vlm_models))
+     except Exception as e:
+         print(f"Error scanning cache: {e}")
+         # Return at least the current model
+         return [state.current_model_name] if state.current_model_name else []
+
+
+ def load_model_by_name(model_name, progress=gr.Progress()):
+     """Load a model and return status."""
+     if not model_name:
+         return "✓ Loaded", gr.update()
+
+     if model_name == state.current_model_name:
+         return "✓ Loaded", gr.update()
+
+     try:
+         progress(0.1, desc="Clearing memory...")
+         progress(0.3, desc="Loading...")
+         state.load(model_name)
+         progress(1.0, desc="Done!")
+
+         return "✓ Loaded", gr.update(value=[])
+     except Exception as e:
+         error_msg = str(e)
+         # Truncate error for display
+         short_err = error_msg[:60] + "..." if len(error_msg) > 60 else error_msg
+         return f"⚠ {short_err}", gr.update()
+
+
+ def refresh_model_list():
+     """Refresh the list of cached models."""
+     models = get_cached_vlm_models()
+     return gr.update(choices=models, value=state.current_model_name)
+
+
+ def extract_image_from_message(message):
+     """Extract image file path from various message formats."""
+     if isinstance(message, dict):
+         if "files" in message and message["files"]:
+             img = message["files"][-1]
+             if isinstance(img, dict) and "path" in img:
+                 return img["path"]
+             elif isinstance(img, str):
+                 return img
+         if "file" in message and message["file"]:
+             f = message["file"]
+             if isinstance(f, dict) and "path" in f:
+                 return f["path"]
+             elif isinstance(f, str):
+                 return f
+     elif isinstance(message, str):
+         return message if message else ""
+     return ""
+
+
+ def extract_text_from_message(message):
+     """Extract text content from various message formats."""
+     if isinstance(message, str):
+         return message
+     if isinstance(message, dict):
+         if "text" in message:
+             return message["text"] or ""
+         if "content" in message:
+             content = message["content"]
+             if isinstance(content, str):
+                 return content
+             elif isinstance(content, list):
+                 text_parts = []
+                 for c in content:
+                     if isinstance(c, str):
+                         text_parts.append(c)
+                     elif isinstance(c, dict) and c.get("type") == "text":
+                         text_parts.append(c.get("text", ""))
+                 return " ".join(text_parts)
+     return ""
+
+
+ def chat(
+     message,
+     history,
+     temperature,
+     max_tokens,
+     top_p,
+     repetition_penalty,
+     system_prompt,
+ ):
+     global stop_generation
+     stop_generation.clear()
+
+     image_file = extract_image_from_message(message)
+     num_images = 1 if image_file else 0
+
+     if state.config["model_type"] != "paligemma":
+         chat_history = []
+
+         if system_prompt and system_prompt.strip():
+             chat_history.append({"role": "system", "content": system_prompt.strip()})
+
+         for item in history:
+             if isinstance(item, dict):
+                 role = item.get("role", "user")
+                 content = item.get("content", "")
+                 if isinstance(content, str):
+                     pass
+                 elif isinstance(content, dict) and "text" in content:
+                     content = content["text"]
+                 elif isinstance(content, list):
+                     text_parts = []
+                     for c in content:
+                         if isinstance(c, str):
+                             text_parts.append(c)
+                         elif isinstance(c, dict) and c.get("type") == "text":
+                             text_parts.append(c.get("text", ""))
+                     content = " ".join(text_parts) if text_parts else ""
+                 else:
+                     content = ""
+                 if role == "assistant" and isinstance(content, str) and content:
+                     content = content.split("\n\n---\n")[0]
+                 if content:
+                     chat_history.append({"role": role, "content": content})
+             elif isinstance(item, (list, tuple)):
+                 if isinstance(item[0], str):
+                     chat_history.append({"role": "user", "content": item[0]})
+                 elif isinstance(item[0], dict) and "text" in item[0]:
+                     chat_history.append({"role": "user", "content": item[0]["text"]})
+                 if item[1] is not None:
+                     content = (
+                         item[1].split("\n\n---\n")[0]
+                         if isinstance(item[1], str)
+                         else item[1]
+                     )
+                     chat_history.append({"role": "assistant", "content": content})
+
+         chat_history.append(
+             {"role": "user", "content": extract_text_from_message(message)}
+         )
+
+         messages = []
+         for i, m in enumerate(chat_history):
+             skip_token = True
+             if i == len(chat_history) - 1 and m["role"] == "user" and image_file:
+                 skip_token = False
+             messages.append(
+                 get_message_json(
+                     state.config["model_type"],
+                     m["content"],
+                     role=m["role"],
+                     skip_image_token=skip_token,
+                     num_images=num_images if not skip_token else 0,
+                 )
+             )
+
+         messages = get_chat_template(
+             state.processor, messages, add_generation_prompt=True
+         )
+
+     else:
+         messages = extract_text_from_message(message)
+
+     response = ""
+     last_chunk = None
+
+     gen_kwargs = {
+         "max_tokens": max_tokens,
+         "temperature": temperature,
+     }
+
+     if top_p < 1.0:
+         gen_kwargs["top_p"] = top_p
+     if repetition_penalty != 1.0:
+         gen_kwargs["repetition_penalty"] = repetition_penalty
+
+     for chunk in stream_generate(
+         state.model,
+         state.processor,
+         messages,
+         image=image_file,
+         **gen_kwargs,
+     ):
+         if stop_generation.is_set():
+             response += "\n\n*[Generation stopped]*"
+             yield response
+             return
+
+         response += chunk.text
+         last_chunk = chunk
+         yield response
+
+     if last_chunk is not None:
+         stats = (
+             f"\n\n---\n"
+             f"<sub>📊 Prompt: {last_chunk.prompt_tokens} tokens @ {last_chunk.prompt_tps:.1f} t/s | "
+             f"Generation: {last_chunk.generation_tokens} tokens @ {last_chunk.generation_tps:.1f} t/s | "
+             f"Peak memory: {last_chunk.peak_memory:.2f} GB</sub>"
+         )
+         yield response + stats
+
+
+ def stop_generating():
+     """Set the stop flag to interrupt generation."""
+     stop_generation.set()
+     return gr.update(interactive=False)
+
+
+ # Create custom theme with dark mode support
+ theme = gr.themes.Soft(
+     primary_hue="blue",
+     secondary_hue="slate",
+ ).set(
+     body_background_fill="*neutral_50",
+     body_background_fill_dark="*neutral_950",
+     block_background_fill="*neutral_100",
+     block_background_fill_dark="*neutral_900",
+ )
+
+ # Get initial model list
+ initial_models = get_cached_vlm_models()
+
+ # JavaScript to toggle dark mode and set dark as default
+ dark_mode_js = """
+ () => {
+     // Always set dark mode on load unless user explicitly chose light
+     const savedTheme = localStorage.getItem('theme');
+     const isDark = savedTheme !== 'light';
+     document.body.classList.toggle('dark', isDark);
+     return isDark ? '☀️' : '🌙';
+ }
+ """
+
+ toggle_dark_js = """
+ () => {
+     const isDark = document.body.classList.toggle('dark');
+     localStorage.setItem('theme', isDark ? 'dark' : 'light');
+     return isDark ? '☀️' : '🌙';
+ }
+ """
+
+ # JavaScript to persist and restore selected model
+ save_model_js = """
+ (model_name) => {
+     if (model_name) {
+         localStorage.setItem('mlx_vlm_model', model_name);
+     }
+     return model_name;
+ }
+ """
+
+ load_model_js = """
+ (server_model) => {
+     const savedModel = localStorage.getItem('mlx_vlm_model');
+     // Return saved model if available, otherwise use server's current model
+     return savedModel || server_model;
+ }
+ """
+
+ with gr.Blocks(fill_height=True, title="MLX-VLM Chat", theme=theme) as demo:
+     gr.Markdown("## MLX-VLM Chat UI")
+
+     # Model selector row
+     with gr.Row():
+         with gr.Column(scale=5):
+             model_dropdown = gr.Dropdown(
+                 label="Model",
+                 choices=initial_models,
+                 value=state.current_model_name,
+                 show_label=True,
+                 allow_custom_value=True,
+             )
+         with gr.Column(scale=0):
+             refresh_btn = gr.Button("🔄", size="sm", min_width=20, scale=0)
+             theme_btn = gr.Button("☀️", size="sm", min_width=20, scale=0)
+         with gr.Column(scale=5):
+             model_status = gr.Textbox(
+                 value="✓ Loaded",
+                 label="Status",
+                 interactive=False,
+             )
+
+     # Main controls row
+     with gr.Row():
+         with gr.Column(scale=6):
+             with gr.Accordion("⚙️ Generation Settings", open=False):
+                 with gr.Row():
+                     temperature = gr.Slider(
+                         minimum=0,
+                         maximum=2,
+                         step=0.05,
+                         value=0.1,
+                         label="Temperature",
+                         info="Higher = more creative, lower = more focused",
+                     )
+                     max_tokens = gr.Slider(
+                         minimum=128,
+                         maximum=4096,
+                         step=64,
+                         value=1024,
+                         label="Max Tokens",
+                         info="Maximum length of response",
+                     )
+                 with gr.Row():
+                     top_p = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         step=0.05,
+                         value=1.0,
+                         label="Top-p (Nucleus Sampling)",
+                         info="1.0 = disabled, lower = more focused",
+                     )
+                     repetition_penalty = gr.Slider(
+                         minimum=1.0,
+                         maximum=2.0,
+                         step=0.05,
+                         value=1.0,
+                         label="Repetition Penalty",
+                         info="1.0 = disabled, higher = less repetition",
+                     )
+                 with gr.Row():
+                     system_prompt = gr.Textbox(
+                         label="System Prompt (optional)",
+                         placeholder="You are a helpful assistant...",
+                         lines=2,
+                         max_lines=4,
+                     )
+
+         with gr.Column(scale=1, min_width=200):
+             stop_btn = gr.Button("⏹️ Stop", variant="stop", size="sm")
+
+     # Chatbot component
+     chatbot = gr.Chatbot(
+         height=chatbot_height,
+         scale=1,
+         buttons=["copy", "copy_all"],
+     )
+
+     # Chat interface
+     chat_interface = gr.ChatInterface(
+         fn=chat,
+         additional_inputs=[
+             temperature,
+             max_tokens,
+             top_p,
+             repetition_penalty,
+             system_prompt,
+         ],
+         multimodal=True,
+         fill_height=True,
+         chatbot=chatbot,
+         save_history=True,
+     )
+
+     # Connect model selector
+     model_dropdown.change(
+         fn=load_model_by_name,
+         inputs=[model_dropdown],
+         outputs=[model_status, chatbot],
+     ).then(
+         fn=None,
+         inputs=[model_dropdown],
+         js=save_model_js,
+     )
+     refresh_btn.click(
+         fn=refresh_model_list,
+         outputs=[model_dropdown],
+     )
+
+     # Connect theme toggle
+     theme_btn.click(fn=None, js=toggle_dark_js, outputs=[theme_btn])
+
+     # On page load: restore theme and model from localStorage
+     demo.load(fn=None, js=dark_mode_js, outputs=[theme_btn])
+     demo.load(
+         fn=lambda: state.current_model_name,
+         inputs=[],
+         outputs=[model_dropdown],
+         js=load_model_js,
+     )
+
+     # Connect control buttons
+     stop_btn.click(fn=stop_generating, outputs=[stop_btn])
+
+
+ def main():
+     demo.launch(inbrowser=True)
+
+
+ if __name__ == "__main__":
+     main()
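
For reference, the UI above wraps a generation path that can also be exercised directly from Python. The following is a minimal sketch, not part of the package diff: it assumes the mlx_vlm API matches the imports used by chat_ui.py (load, load_config, get_message_json, get_chat_template, stream_generate), and "example.jpg" is a hypothetical local image path.

# Minimal sketch of the generation path that chat_ui.py drives.
# Assumes the mlx_vlm API matches the imports in chat_ui.py above;
# "example.jpg" is a hypothetical local image path.
from mlx_vlm import load
from mlx_vlm.generate import stream_generate
from mlx_vlm.prompt_utils import get_chat_template, get_message_json
from mlx_vlm.utils import load_config

model_name = "qnguyen3/nanoLLaVA"  # same default as parse_arguments()
config = load_config(model_name)
model, processor = load(model_name, processor_kwargs={"trust_remote_code": True})

# Build a single-turn prompt carrying one image token, as chat() does
# for the final user message.
messages = [
    get_message_json(
        config["model_type"],
        "Describe this image.",
        role="user",
        skip_image_token=False,
        num_images=1,
    )
]
prompt = get_chat_template(processor, messages, add_generation_prompt=True)

# Stream tokens the same way the UI does, printing them as they arrive.
for chunk in stream_generate(
    model, processor, prompt, image="example.jpg", max_tokens=256, temperature=0.1
):
    print(chunk.text, end="", flush=True)

Run as a script, this mirrors a single chat turn; the UI adds history handling, stop support, and the token-throughput stats appended after generation.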