fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/chat_ui.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import gc
|
|
3
|
+
import json
|
|
4
|
+
import threading
|
|
5
|
+
|
|
6
|
+
import gradio as gr
|
|
7
|
+
import mlx.core as mx
|
|
8
|
+
|
|
9
|
+
from mlx_vlm import load
|
|
10
|
+
|
|
11
|
+
from .generate import stream_generate
|
|
12
|
+
from .prompt_utils import get_chat_template, get_message_json
|
|
13
|
+
from .utils import load_config, load_image_processor
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def parse_arguments(argv=None):
    """Parse the chat UI's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets callers and tests drive the parser directly.

    Returns:
        argparse.Namespace: parsed options; ``model`` holds the local model
        directory or Hugging Face repo id to load.
    """
    parser = argparse.ArgumentParser(
        description="Generate text from an image using a model."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="qnguyen3/nanoLLaVA",
        help="The path to the local model directory or Hugging Face repo.",
    )
    return parser.parse_args(argv)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Global state for model
|
|
30
|
+
class ModelState:
    """Holder for the currently loaded model and its companion objects.

    Attributes are ``None`` until a load succeeds; ``current_model_name`` is
    the name passed to the last successful ``load`` (``None`` otherwise).
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.config = None
        self.image_processor = None
        self.current_model_name = None

    def _release(self):
        """Drop references to the loaded model and free cached MLX memory.

        Attributes are reset to ``None`` (never deleted), so the instance stays
        usable even if a subsequent load fails.
        """
        self.model = None
        self.processor = None
        self.config = None
        self.image_processor = None
        self.current_model_name = None
        # Collect first so freed arrays are returned to MLX's cache, then
        # empty that cache.
        gc.collect()
        # Newer MLX exposes ``mx.clear_cache``; older versions only provide the
        # ``mx.metal`` variant.
        clear_cache = getattr(mx, "clear_cache", None) or mx.metal.clear_cache
        clear_cache()

    def load(self, model_name):
        """Load ``model_name``, clearing the previous model from memory first.

        Args:
            model_name: Local model directory or Hugging Face repo id.

        Raises:
            Whatever the underlying loaders raise. In that case the previous
            model has already been released and all attributes are ``None``,
            so the state is empty rather than half-populated, and loading any
            name again (including the same one) starts cleanly.
        """
        if self.model is not None:
            self._release()

        # Load everything into locals first and publish only on success.
        config = load_config(model_name)
        # NOTE: trust_remote_code lets the processor run code shipped with the
        # repo; only load repos you trust.
        model, processor = load(
            model_name, processor_kwargs={"trust_remote_code": True}
        )
        image_processor = load_image_processor(model_name)

        self.config = config
        self.model = model
        self.processor = processor
        self.image_processor = image_processor
        self.current_model_name = model_name
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# CSS height for the chat window: follows the viewport height minus room for
# the controls, never smaller than 380px nor larger than 820px.
chatbot_height = "clamp(380px, calc(100vh - 450px), 820px)"

# Event set by the Stop button and polled by ``chat`` while it streams.
stop_generation = threading.Event()

# Single holder for the loaded model, read by every UI callback.
state = ModelState()

# CLI options are parsed and the initial model is loaded at import time, so
# the interface below is built around a model that is already in memory.
args = parse_arguments()
state.load(args.model)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _snapshot_has_vision_config(snapshot_path):
    """Return True if ``snapshot_path / "config.json"`` declares ``vision_config``.

    A missing path, an unreadable file, invalid UTF-8 or malformed JSON all
    count as "not a vision model" instead of raising, so one broken cache
    entry cannot hide the others.
    """
    if not snapshot_path:
        return False
    config_path = snapshot_path / "config.json"
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return False
    return isinstance(config, dict) and "vision_config" in config


def get_cached_vlm_models():
    """Scan the Hugging Face cache for vision-capable models.

    A cached model repo qualifies when its snapshot's ``config.json`` has a
    ``vision_config`` key. The ``main`` ref's snapshot is checked when that ref
    exists; otherwise each cached revision is tried until one qualifies.

    Returns:
        Sorted, de-duplicated list of repo ids. The currently loaded model is
        always included. If scanning fails (e.g. ``huggingface_hub`` cannot be
        imported), the error is printed and only the current model, if any, is
        returned.
    """
    try:
        from huggingface_hub import scan_cache_dir

        vlm_models = set()
        for repo in scan_cache_dir().repos:
            if repo.repo_type != "model":
                continue

            refs = getattr(repo, "refs", None) or {}
            if "main" in refs:
                candidates = [refs["main"]]
            else:
                # No usable "main" ref: fall back to any cached revision.
                candidates = getattr(repo, "revisions", None) or ()

            for rev in candidates:
                snapshot_path = getattr(rev, "snapshot_path", None)
                if _snapshot_has_vision_config(snapshot_path):
                    vlm_models.add(repo.repo_id)
                    break

        # Ensure current model is in the list
        if state.current_model_name:
            vlm_models.add(state.current_model_name)

        return sorted(vlm_models)
    except Exception as e:
        print(f"Error scanning cache: {e}")
        # Return at least the current model
        return [state.current_model_name] if state.current_model_name else []
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def load_model_by_name(model_name, progress=gr.Progress()):
    """Switch the UI to ``model_name``.

    Returns a ``(status_text, chatbot_update)`` pair for the status box and
    the chatbot. An empty name or the already-active model is a no-op. A
    successful switch clears the chat history. A failure reports a shortened
    error message and leaves the chatbot untouched.
    """
    nothing_to_do = (not model_name) or model_name == state.current_model_name
    if nothing_to_do:
        return "â Loaded", gr.update()

    try:
        for fraction, label in ((0.1, "Clearing memory..."), (0.3, "Loading...")):
            progress(fraction, desc=label)
        state.load(model_name)
        progress(1.0, desc="Done!")

        return "â Loaded", gr.update(value=[])
    except Exception as exc:
        # Keep the status box compact: at most 60 characters plus an ellipsis.
        message = str(exc)
        if len(message) > 60:
            message = message[:60] + "..."
        return f"â {message}", gr.update()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def refresh_model_list():
    """Re-scan the cache and repopulate the model dropdown.

    The currently loaded model stays selected.
    """
    return gr.update(
        choices=get_cached_vlm_models(),
        value=state.current_model_name,
    )
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _attachment_path(entry):
    """Return the path of one attachment entry, or None if it is unrecognised."""
    if isinstance(entry, dict) and "path" in entry:
        return entry["path"]
    if isinstance(entry, str):
        return entry
    return None


def extract_image_from_message(message):
    """Extract the image file path attached to a chat message.

    Multimodal messages are dicts. The path comes from the *last* entry of
    ``files`` (a ``{"path": ...}`` dict or a plain string), falling back to a
    single ``file`` entry of the same shapes. Only one image is used per
    message.

    Returns:
        The path string, or ``""`` when the message has no usable attachment.
        Plain-string messages are text-only and never yield a path. They used
        to be returned verbatim, which made the prompt text be treated as an
        image file.
    """
    if not isinstance(message, dict):
        return ""

    files = message.get("files")
    if files:
        path = _attachment_path(files[-1])
        if path is not None:
            return path

    single = message.get("file")
    if single:
        path = _attachment_path(single)
        if path is not None:
            return path

    return ""
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def extract_text_from_message(message):
    """Return the plain-text part of a chat message.

    Accepts a bare string, a multimodal dict carrying ``text``, or a dict with
    ``content`` that is either a string or a list of parts. Lists keep only
    string parts and ``{"type": "text"}`` parts, joined by single spaces.
    Anything unrecognised yields ``""``.
    """
    if isinstance(message, str):
        return message
    if not isinstance(message, dict):
        return ""

    if "text" in message:
        return message["text"] or ""
    if "content" not in message:
        return ""

    body = message["content"]
    if isinstance(body, str):
        return body
    if not isinstance(body, list):
        return ""

    def _is_text_part(part):
        return isinstance(part, str) or (
            isinstance(part, dict) and part.get("type") == "text"
        )

    return " ".join(
        part if isinstance(part, str) else part.get("text", "")
        for part in body
        if _is_text_part(part)
    )
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# Separator placed between a reply and its stats footer. It is also used to
# strip that footer from earlier replies before they are replayed as history.
_STATS_SEPARATOR = "\n\n---\n"


def _history_content_to_text(content):
    """Flatten one history ``content`` value to text.

    Strings pass through unchanged, and a ``{"text": ...}`` dict yields its
    ``text`` value. A list keeps its string and ``{"type": "text"}`` parts,
    joined by spaces. Anything else (e.g. file attachments) becomes ``""`` so
    the turn is skipped.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, dict) and "text" in content:
        return content["text"]
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                parts.append(part.get("text", ""))
        return " ".join(parts)
    return ""


def _build_chat_history(history, message, system_prompt):
    """Convert Gradio history plus the new message into role/content dicts.

    Handles both the messages format (dicts with ``role``/``content``) and the
    ``(user, assistant)`` pair format. Stats footers on earlier assistant
    replies are removed so they are not fed back to the model.
    """
    chat_history = []
    if system_prompt and system_prompt.strip():
        chat_history.append({"role": "system", "content": system_prompt.strip()})

    for item in history:
        if isinstance(item, dict):
            role = item.get("role", "user")
            content = _history_content_to_text(item.get("content", ""))
            if role == "assistant" and isinstance(content, str) and content:
                content = content.split(_STATS_SEPARATOR)[0]
            if content:
                chat_history.append({"role": role, "content": content})
        elif isinstance(item, (list, tuple)):
            user_part = item[0]
            if isinstance(user_part, str):
                chat_history.append({"role": "user", "content": user_part})
            elif isinstance(user_part, dict) and "text" in user_part:
                chat_history.append({"role": "user", "content": user_part["text"]})
            reply = item[1]
            if reply is not None:
                if isinstance(reply, str):
                    reply = reply.split(_STATS_SEPARATOR)[0]
                chat_history.append({"role": "assistant", "content": reply})

    chat_history.append(
        {"role": "user", "content": extract_text_from_message(message)}
    )
    return chat_history


def _format_stats(chunk):
    """Build the token-count, throughput and peak-memory footer for a finished reply."""
    return (
        f"{_STATS_SEPARATOR}"
        f"<sub>đ Prompt: {chunk.prompt_tokens} tokens @ {chunk.prompt_tps:.1f} t/s | "
        f"Generation: {chunk.generation_tokens} tokens @ {chunk.generation_tps:.1f} t/s | "
        f"Peak memory: {chunk.peak_memory:.2f} GB</sub>"
    )


def chat(
    message,
    history,
    temperature,
    max_tokens,
    top_p,
    repetition_penalty,
    system_prompt,
):
    """Stream a reply to ``message`` for ``gr.ChatInterface``.

    Yields the growing response text after each generated chunk. If the Stop
    button sets ``stop_generation``, the stream ends with a
    "Generation stopped" marker. Otherwise a final yield appends a stats
    footer.

    Only the image attached to the newest message is sent to the model.
    Earlier turns are replayed as text only.
    """
    stop_generation.clear()

    if getattr(state, "model", None) is None:
        # A failed model switch can leave nothing loaded; report it instead of
        # crashing inside generation.
        yield "No model is loaded. Pick a model from the dropdown to continue."
        return

    image_file = extract_image_from_message(message)
    num_images = 1 if image_file else 0
    model_type = state.config["model_type"]

    if model_type != "paligemma":
        chat_history = _build_chat_history(history, message, system_prompt)
        last_index = len(chat_history) - 1
        messages = []
        for i, m in enumerate(chat_history):
            # Only the newest user turn carries the image placeholder.
            with_image = i == last_index and m["role"] == "user" and bool(image_file)
            messages.append(
                get_message_json(
                    model_type,
                    m["content"],
                    role=m["role"],
                    skip_image_token=not with_image,
                    num_images=num_images if with_image else 0,
                )
            )
        messages = get_chat_template(
            state.processor, messages, add_generation_prompt=True
        )
    else:
        # PaliGemma is prompted with the raw text of the current message only.
        messages = extract_text_from_message(message)

    gen_kwargs = {
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    # top_p == 1.0 and repetition_penalty == 1.0 mean "disabled" in the UI.
    if top_p < 1.0:
        gen_kwargs["top_p"] = top_p
    if repetition_penalty != 1.0:
        gen_kwargs["repetition_penalty"] = repetition_penalty

    response = ""
    last_chunk = None
    for chunk in stream_generate(
        state.model,
        state.processor,
        messages,
        image=image_file,
        **gen_kwargs,
    ):
        # Polled per streamed chunk, so a stop takes effect at the next chunk.
        if stop_generation.is_set():
            yield response + "\n\n*[Generation stopped]*"
            return

        response += chunk.text
        last_chunk = chunk
        yield response

    if last_chunk is not None:
        yield response + _format_stats(last_chunk)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def stop_generating():
    """Ask the in-flight ``chat`` stream to stop.

    Sets the shared ``stop_generation`` event, which ``chat`` polls between
    streamed chunks and clears at the start of each new reply.

    Returns:
        An update that keeps the Stop button enabled. Nothing else in the UI
        re-enables it, so returning ``interactive=False`` (as before) left the
        button permanently disabled after its first use.
    """
    stop_generation.set()
    return gr.update(interactive=True)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# Create custom theme with dark mode support
|
|
327
|
+
theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")
# Surface colours for light mode and their ``*_dark`` counterparts, which
# apply while the page carries the ``dark`` class.
theme = theme.set(
    body_background_fill="*neutral_50",
    body_background_fill_dark="*neutral_950",
    block_background_fill="*neutral_100",
    block_background_fill_dark="*neutral_900",
)

# Dropdown choices at startup: cached VLM repos plus the loaded model.
initial_models = get_cached_vlm_models()
|
|
339
|
+
|
|
340
|
+
# JavaScript to toggle dark mode and set dark as default
|
|
341
|
+
# Page-load handler: dark mode is the default unless localStorage['theme'] is
# 'light'. Returns the label to show on the theme button.
dark_mode_js = """
() => {
    // Always set dark mode on load unless user explicitly chose light
    const savedTheme = localStorage.getItem('theme');
    const isDark = savedTheme !== 'light';
    document.body.classList.toggle('dark', isDark);
    return isDark ? 'âī¸' : 'đ';
}
"""

# Theme-button handler: flips the body's 'dark' class, persists the choice in
# localStorage['theme'], and returns the button's new label.
toggle_dark_js = """
() => {
    const isDark = document.body.classList.toggle('dark');
    localStorage.setItem('theme', isDark ? 'dark' : 'light');
    return isDark ? 'âī¸' : 'đ';
}
"""
|
|
358
|
+
|
|
359
|
+
# JavaScript to persist and restore selected model
|
|
360
|
+
# Stores a non-empty model name in localStorage['mlx_vlm_model'] and passes the
# value through unchanged.
save_model_js = """
(model_name) => {
    if (model_name) {
        localStorage.setItem('mlx_vlm_model', model_name);
    }
    return model_name;
}
"""

# Prefers the model name saved in localStorage['mlx_vlm_model'], falling back
# to the value supplied by the server.
load_model_js = """
(server_model) => {
    const savedModel = localStorage.getItem('mlx_vlm_model');
    // Return saved model if available, otherwise use server's current model
    return savedModel || server_model;
}
"""
|
|
376
|
+
|
|
377
|
+
# UI definition. Gradio places components in the order they are created inside
# these ``with`` contexts, so statement order here is the on-screen layout.
with gr.Blocks(fill_height=True, title="MLX-VLM Chat") as demo:
    gr.Markdown("## MLX-VLM Chat UI")

    # Model selector row
    with gr.Row():
        with gr.Column(scale=5):
            # allow_custom_value lets users type a repo id that is not in the
            # scanned cache list.
            model_dropdown = gr.Dropdown(
                label="Model",
                choices=initial_models,
                value=state.current_model_name,
                show_label=True,
                allow_custom_value=True,
            )
        with gr.Column(scale=0):
            refresh_btn = gr.Button("đ", size="sm", min_width=20, scale=0)
            theme_btn = gr.Button("âī¸", size="sm", min_width=20, scale=0)
        with gr.Column(scale=5):
            model_status = gr.Textbox(
                value="â Loaded",
                label="Status",
                interactive=False,
            )

    # Main controls row
    with gr.Row():
        with gr.Column(scale=6):
            with gr.Accordion("âī¸ Generation Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0,
                        maximum=2,
                        step=0.05,
                        value=0.1,
                        label="Temperature",
                        info="Higher = more creative, lower = more focused",
                    )
                    max_tokens = gr.Slider(
                        minimum=128,
                        maximum=4096,
                        step=64,
                        value=1024,
                        label="Max Tokens",
                        info="Maximum length of response",
                    )
                with gr.Row():
                    # 1.0 means "not forwarded to generation" (see chat).
                    top_p = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        step=0.05,
                        value=1.0,
                        label="Top-p (Nucleus Sampling)",
                        info="1.0 = disabled, lower = more focused",
                    )
                    # 1.0 means "not forwarded to generation" (see chat).
                    repetition_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=2.0,
                        step=0.05,
                        value=1.0,
                        label="Repetition Penalty",
                        info="1.0 = disabled, higher = less repetition",
                    )
                with gr.Row():
                    system_prompt = gr.Textbox(
                        label="System Prompt (optional)",
                        placeholder="You are a helpful assistant...",
                        lines=2,
                        max_lines=4,
                    )

        with gr.Column(scale=1, min_width=200):
            stop_btn = gr.Button("âšī¸ Stop", variant="stop", size="sm")

    # Chatbot component
    chatbot = gr.Chatbot(
        height=chatbot_height,
        scale=1,
        buttons=["copy", "copy_all"],
    )

    # Chat interface
    # additional_inputs are passed to ``chat`` after (message, history), so
    # their order must match chat's parameter order.
    chat_interface = gr.ChatInterface(
        fn=chat,
        additional_inputs=[
            temperature,
            max_tokens,
            top_p,
            repetition_penalty,
            system_prompt,
        ],
        multimodal=True,
        fill_height=True,
        chatbot=chatbot,
        save_history=True,
    )

    # Connect model selector
    # Changing the dropdown reloads the module-level ``state`` (one model for
    # the whole server) and, on success, clears the chat. The chosen name is
    # then saved to localStorage so a later page load can restore it.
    model_dropdown.change(
        fn=load_model_by_name,
        inputs=[model_dropdown],
        outputs=[model_status, chatbot],
    ).then(
        fn=None,
        inputs=[model_dropdown],
        js=save_model_js,
    )
    refresh_btn.click(
        fn=refresh_model_list,
        outputs=[model_dropdown],
    )

    # Connect theme toggle
    theme_btn.click(fn=None, js=toggle_dark_js, outputs=[theme_btn])

    # On page load: restore theme and model from localStorage
    demo.load(fn=None, js=dark_mode_js, outputs=[theme_btn])
    demo.load(
        fn=lambda: state.current_model_name,
        inputs=[],
        outputs=[model_dropdown],
        js=load_model_js,
    )

    # Connect control buttons
    # stop_generating sets the event polled by ``chat``; its return value
    # updates the Stop button itself.
    stop_btn.click(fn=stop_generating, outputs=[stop_btn])
|
|
501
|
+
|
|
502
|
+
|
|
503
|
+
def main():
    """Serve the chat UI and open it in the default web browser."""
    launch_options = dict(inbrowser=True, theme=theme)
    demo.launch(**launch_options)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
# Launch the UI when this module is executed directly.
if __name__ == "__main__":
    main()
|