nexaai-1.0.14-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.15-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +261 -0
- nexaai/mlx_backend/vlm/interface.py +11 -2
- nexaai/mlx_backend/vlm/main.py +168 -9
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/vlm.py +4 -3
- nexaai/vlm_impl/mlx_vlm_impl.py +3 -1
- nexaai/vlm_impl/pybind_vlm_impl.py +3 -1
- {nexaai-1.0.14.dist-info → nexaai-1.0.15.dist-info}/METADATA +1 -1
- {nexaai-1.0.14.dist-info → nexaai-1.0.15.dist-info}/RECORD +23 -13
- {nexaai-1.0.14.dist-info → nexaai-1.0.15.dist-info}/WHEEL +0 -0
- {nexaai-1.0.14.dist-info → nexaai-1.0.15.dist-info}/top_level.txt +0 -0
nexaai/_stub.cpython-310-darwin.so
Binary file

nexaai/_version.py
CHANGED

nexaai/binds/libnexa_bridge.dylib
Binary file

nexaai/binds/nexa_mlx/libnexa_plugin.dylib
Binary file

nexaai/mlx_backend/vlm/generate_qwen3_vl.py
ADDED
@@ -0,0 +1,261 @@
+import argparse
+import json
+import sys
+import os
+import mlx.core as mx
+import mlx.nn as nn
+import time
+from PIL import Image
+import requests
+import numpy as np
+from pathlib import Path
+from huggingface_hub import snapshot_download
+
+# Add current directory to path for imports
+curr_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(curr_dir)
+sys.path.append(os.path.dirname(curr_dir))
+
+# Add the qwen3vl model directory to path
+qwen3vl_dir = os.path.join(curr_dir, "modeling", "models", "qwen3_vl")
+sys.path.append(qwen3vl_dir)
+
+# Import required modules for quantized loading
+from transformers import AutoTokenizer
+
+# Try relative imports first, fallback to sys.path approach for Nuitka compatibility
+try:
+    from .modeling.models.qwen3_vl.llm_common.generate import nexa_generate_step
+    from .modeling.models.qwen3_vl.llm_common.cache import make_prompt_cache
+    from .modeling.models.qwen3_vl.qwen3vl import (
+        VEGModel, LLMModel, ModelArgs, VisionConfig, TextConfig, handle_multimodal_embeds
+    )
+    from .modeling.models.qwen3_vl.processor import Qwen3VLProcessor
+except ImportError:
+    # Fallback for Nuitka compiled environment - use sys.path approach
+    from llm_common.generate import nexa_generate_step
+    from llm_common.cache import make_prompt_cache
+    from qwen3vl import VEGModel, LLMModel, ModelArgs, VisionConfig, TextConfig, handle_multimodal_embeds
+    from processor import Qwen3VLProcessor
+
+from ml import ChatMessage
+from dataclasses import dataclass
+from typing import Any, Generator, List, Optional, Sequence, Tuple, Union
+from .generate import GenerationResult
+
+@dataclass
+class Qwen3VLBundledModel:
+    """Container for Qwen3-VL vision and language models."""
+    vision_model: VEGModel
+    llm_model: LLMModel
+
+
+def _ensure_list(x: Union[str, List[str], None]) -> Optional[List[str]]:
+    if x is None:
+        return None
+    return x if isinstance(x, list) else [x]
+
+
+def load_qwen3_vl(
+    path_or_repo: str,
+    adapter_path: Optional[str] = None,
+    lazy: bool = False,
+    revision: Optional[str] = None,
+    **kwargs,
+) -> Tuple[Qwen3VLBundledModel, Qwen3VLProcessor]:
+    """Load Qwen3-VL quantized models and processor.
+
+    Parameters are aligned with .generate.load for compatibility.
+    """
+    model_path = Path(path_or_repo)
+    if not model_path.exists():
+        if "/" in path_or_repo:
+            model_path = Path(snapshot_download(
+                repo_id=path_or_repo, repo_type="model", revision=revision))
+        else:
+            # Fallback to local modelfiles directory
+            model_path = Path(qwen3vl_dir) / "modelfiles"
+            if not model_path.exists():
+                model_path = Path(curr_dir) / "modelfiles"
+
+    # Model configs (kept identical to main)
+    vision_config = VisionConfig(
+        hidden_size=1024,
+        intermediate_size=4096,
+        num_heads=16,
+        num_hidden_layers=24,
+        patch_size=16,
+        temporal_patch_size=2,
+        in_channels=3,
+        hidden_act="gelu",
+        spatial_merge_size=2,
+        out_hidden_size=2560,
+        num_position_embeddings=2304,
+        deepstack_visual_indexes=[5, 11, 17],
+    )
+
+    text_config = TextConfig(
+        model_type="qwen3vl",
+        hidden_size=2560,
+        num_hidden_layers=36,
+        intermediate_size=9728,
+        num_attention_heads=32,
+        num_key_value_heads=8,
+        rms_norm_eps=1e-6,
+        vocab_size=151936,
+        max_position_embeddings=32768,
+        rope_theta=5000000.0,
+        head_dim=128,
+        tie_word_embeddings=True,
+        attention_bias=False,
+        attention_dropout=0.0,
+        rope_scaling={"mrope_section": [24, 20, 20],
+                      "rope_type": "default", "type": "default"},
+    )
+
+    vision_model = VEGModel(vision_config)
+    llm_model = LLMModel(text_config)
+
+    # Try to load LLM model from available files in order of preference
+    preferred_order = [
+        ("qwen3vl-llm-4B-q4_0.safetensors", 4),
+        ("qwen3vl-llm-4B-q8_0.safetensors", 8),
+        ("qwen3vl-llm-4B-f32.safetensors", 32)
+    ]
+
+    llm_weights_path = None
+    quantization_bits = None
+
+    # Try loading in order of preference
+    for filename, bits in preferred_order:
+        candidate_path = model_path / filename
+        if candidate_path.exists():
+            llm_weights_path = candidate_path
+            quantization_bits = bits
+            break
+
+    if llm_weights_path is None:
+        # Fallback to original hardcoded path for backward compatibility
+        llm_weights_path = model_path / "qwen3vl-llm-4B-q4_0.safetensors"
+        quantization_bits = 4
+
+    vision_weights_path = model_path / "qwen3vl-vision-4B-f16.safetensors"
+
+    if not vision_weights_path.exists() or not llm_weights_path.exists():
+        raise FileNotFoundError(
+            f"Missing safetensors. Vision: {vision_weights_path}, LLM: {llm_weights_path}"
+        )
+
+    # Load weights (vision fp16, llm with detected quantization)
+    vision_model.set_dtype(mx.float16)
+    vision_model.load_weights(str(vision_weights_path), strict=True)
+
+    # Apply quantization if needed and load LLM weights
+    if quantization_bits in [4, 8]:
+        nn.quantize(llm_model, bits=quantization_bits, group_size=64,
+                    class_predicate=quant_predicate)
+    # For f32 (32-bit), no quantization needed
+
+    llm_model.load_weights(str(llm_weights_path), strict=True)
+
+    # Tokenizer and processor
+    tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
+    processor = Qwen3VLProcessor(tokenizer=tokenizer)
+
+    return Qwen3VLBundledModel(vision_model=vision_model, llm_model=llm_model), processor
+
+def apply_chat_template_qwen3_vl(messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = False) -> str:
+    """Apply chat template: serialize messages with content as a list of typed items."""
+    messages_dict = []
+    for msg in messages:
+        content_items = [{"type": "text", "text": msg.content}]
+        messages_dict.append({"role": msg.role, "content": content_items})
+    return json.dumps(messages_dict)
+
+
+def stream_generate_qwen3_vl(
+    model: Qwen3VLBundledModel,
+    processor: Qwen3VLProcessor,
+    prompt: str,
+    image: Union[str, List[str]] = None,
+    audio: Union[str, List[str]] = None,
+    max_tokens: int = 512,
+    **kwargs,
+
+) -> Generator[Any, None, None]:
+    """Stream generation yielding .generate.GenerationResult-compatible chunks."""
+    messages = json.loads(prompt)
+    if image is not None:
+        image_list = image if isinstance(image, list) else [image]
+        pil_images = []
+        for p in image_list:
+            try:
+                pil_images.append(Image.open(p))
+            except Exception:
+                continue
+        contents = [{"type": "image", "image": img} for img in pil_images]
+        if messages:
+            if "content" not in messages[-1] or not isinstance(messages[-1]["content"], list):
+                messages[-1]["content"] = []
+            messages[-1]["content"].extend(contents)
+
+    raw_text, processed_images = processor.messages_to_text(
+        messages, add_generation_prompt=True)
+
+    inputs = processor.text_to_input_ids(
+        raw_text, images=processed_images, return_tensors="mlx")
+
+    input_ids = inputs["input_ids"]
+    pixel_values = inputs.get("pixel_values")
+    image_grid_thw = inputs.get("image_grid_thw")
+
+    inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas = handle_multimodal_embeds(
+        model.vision_model, model.llm_model, input_ids, pixel_values, image_grid_thw
+    )
+
+    prompt_cache = make_prompt_cache(model.llm_model, max_kv_size=4096)
+    tokenizer = processor.tokenizer
+
+    # Rough prompt TPS estimation based on input size
+    prompt_start = time.perf_counter()
+    prompt_tps = input_ids.size / max(1e-6, (time.perf_counter() - prompt_start))
+
+    gen_count = 0
+    tic = time.perf_counter()
+
+    for token, logprobs in nexa_generate_step(
+        model=model.llm_model,
+        prompt=None,
+        input_embeddings=inputs_embeds,
+        max_tokens=max_tokens,
+        max_kv_size=4096,
+        prompt_cache=prompt_cache,
+        visual_pos_masks=visual_pos_masks,
+        deepstack_visual_embeds=deepstack_visual_embeds,
+        cos=cos,
+        sin=sin,
+        rope_deltas=rope_deltas,
+    ):
+        if token == tokenizer.eos_token_id:
+            break
+
+        text_piece = tokenizer.decode([token])
+        gen_count += 1
+
+        yield GenerationResult(
+            text=text_piece,
+            token=token,
+            logprobs=logprobs,
+            prompt_tokens=int(input_ids.size),
+            generation_tokens=gen_count,
+            prompt_tps=float(prompt_tps),
+            generation_tps=float(
+                gen_count / max(1e-6, (time.perf_counter() - tic))),
+            peak_memory=float(mx.get_peak_memory() / 1e9),
+        )
+
+def quant_predicate(path: str, mod: nn.Module) -> bool:
+    """Quantization predicate to exclude certain layers from quantization."""
+    if path.endswith("lm_head") or "norm" in path.lower() or "embed" in path.lower():
+        return False
+    return isinstance(mod, (nn.Linear, nn.Embedding))
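For orientation, here is a minimal usage sketch of the new entry points defined in generate_qwen3_vl.py. It is not taken from the package itself: the import path mirrors the file layout above, ChatMessage is the internal ml type used throughout this backend, and the repo id and image path are hypothetical placeholders.

# Hedged sketch only; function names come from the diff above, repo id and image are hypothetical.
from ml import ChatMessage
from nexaai.mlx_backend.vlm.generate_qwen3_vl import (
    load_qwen3_vl, apply_chat_template_qwen3_vl, stream_generate_qwen3_vl,
)

bundle, processor = load_qwen3_vl("NexaAI/qwen3vl-4B-mlx")         # hypothetical repo id
prompt = apply_chat_template_qwen3_vl(
    [ChatMessage(role="user", content="Describe this image.")])    # serialized to a JSON message list
for chunk in stream_generate_qwen3_vl(bundle, processor, prompt,
                                      image="example.jpg",         # hypothetical local file
                                      max_tokens=128):
    print(chunk.text, end="", flush=True)                          # GenerationResult-compatible chunks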
nexaai/mlx_backend/vlm/interface.py
CHANGED

@@ -25,6 +25,8 @@ from profiling import ProfilingMixin, ProfilingData, StopReason
 
 # Import from the actual mlx_vlm structure
 from .generate import generate, stream_generate, load
+from .generate_qwen3_vl import apply_chat_template_qwen3_vl, stream_generate_qwen3_vl, load_qwen3_vl
+
 from .modeling.prompt_utils import apply_chat_template
 
 # --------------------------------------------------------------------------------------

@@ -54,6 +56,7 @@ class VLM(ProfilingMixin):
 
     def __init__(
         self,
+        model_name: Optional[str],
         model_path: Path,
         mmproj_path: Path,
         context_length: int,

@@ -67,11 +70,13 @@ class VLM(ProfilingMixin):
             model_path = os.path.dirname(model_path)
 
         self.model_path = model_path
+        self.model_name = model_name
         self.mmproj_path = mmproj_path
         self.context_length = context_length
         self.device = device
 
-        self.model, self.processor = load(str(model_path))
+        load_impl = load_qwen3_vl if model_name == "qwen3vl" else load
+        self.model, self.processor = load_impl(str(model_path))
 
         # Init deafutl sampler config with defualt.
         self.sampler_config = SamplerConfig()

@@ -228,9 +233,10 @@ class VLM(ProfilingMixin):
         text = ""
         last_result = None
         first_token = True
+        stream_generate_impl = stream_generate_qwen3_vl if self.model_name == "qwen3vl" else stream_generate
 
         try:
-            for result in stream_generate(
+            for result in stream_generate_impl(
                 self.model,
                 self.processor,
                 prompt,

@@ -350,6 +356,9 @@ class VLM(ProfilingMixin):
 
     def apply_chat_template_with_media(self, messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = True) -> str:
         """Apply chat template to messages with proper image/audio token insertion and optional tools support."""
+        if self.model_name == "qwen3vl":
+            return apply_chat_template_qwen3_vl(messages, num_images=num_images, num_audios=num_audios, tools=tools, enable_thinking=enable_thinking)
+
         # Convert ChatMessage objects to dictionaries for the processor
         messages_dict = [{"role": msg.role, "content": msg.content} for msg in messages]
 
nexaai/mlx_backend/vlm/main.py
CHANGED
@@ -3,6 +3,7 @@ from ml import GenerationConfig, SamplerConfig, ChatMessage
 import re
 import os
 import codecs
+import argparse
 
 def parse_media_from_input(user_input):
     """Parse quoted media files from user input and return prompt and media paths"""

@@ -39,12 +40,177 @@ def parse_media_from_input(user_input):
 
     return prompt, image_paths if image_paths else None, audio_paths if audio_paths else None
 
-def test_vlm_generate_stream(model_path):
+def parse_arguments():
+    """Parse command line arguments for the VLM main function."""
+    parser = argparse.ArgumentParser(
+        description="Interactive VLM (Vision-Language Model) conversation interface."
+    )
+    parser.add_argument(
+        "--model_path",
+        type=str,
+        default="mlx-community/gemma-3-4b-it-8bit",
+        help="The path to the local model directory or Hugging Face repo."
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="",
+        help="Specific model name/type (e.g., 'qwen3vl', 'gemma3'). If empty, auto-detect from model_path."
+    )
+    parser.add_argument(
+        "--context_length",
+        type=int,
+        default=2048,
+        help="Context length for the model (default: 2048)."
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Sampling temperature (default: 0.7)."
+    )
+    parser.add_argument(
+        "--top_p",
+        type=float,
+        default=0.9,
+        help="Top-p sampling parameter (default: 0.9)."
+    )
+    parser.add_argument(
+        "--max_tokens",
+        type=int,
+        default=512,
+        help="Maximum tokens to generate (default: 512)."
+    )
+    return parser.parse_args()
+
+def main():
+    """Main function for interactive VLM conversation."""
+    args = parse_arguments()
+
+    # Auto-detect model name if not provided
+    model_name = args.model_name
+    if not model_name:
+        if "qwen" in args.model_path.lower():
+            model_name = "qwen3vl"
+        elif "gemma" in args.model_path.lower():
+            model_name = "gemma3"
+        else:
+            model_name = ""
+
+    # Load the VLM instance
+    vlm = VLM(
+        model_name=model_name,
+        model_path=args.model_path,
+        mmproj_path=None, # Not needed for this model
+        context_length=args.context_length,
+        device=None
+    )
+
+    # Configure sampler
+    sampler_config = SamplerConfig(
+        temperature=args.temperature,
+        top_p=args.top_p
+    )
+    vlm.set_sampler(sampler_config)
+
+    # Chat history using ChatMessage objects
+    chat = []
+
+    print("VLM Multi-round conversation started. Type 'quit' or 'exit' to end.")
+    print("Include images/audios in quotes, e.g.: 'describe \"image1.jpg\" \"image2.png\"'")
+    print("You can also use single quotes: 'describe '/path/to/image.jpg''")
+    print("=" * 50)
+
+    def on_token(text_chunk):
+        """Token callback for streaming"""
+        print(text_chunk, end="", flush=True)
+        return True
+
+    while True:
+        # Get user input
+        user_input = input("\nUser: ").strip()
+
+        # Check for exit commands
+        if user_input.lower() in ["quit", "exit", "q"]:
+            print("Goodbye!")
+            break
+
+        if not user_input:
+            continue
+
+        # Parse media files and prompt from user input
+        prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)
+
+        # If no text prompt after parsing, use the original input
+        if not prompt_text.strip():
+            prompt_text = user_input
+            image_paths = None
+            audio_paths = None
+
+        # Add user message to chat history using ChatMessage
+        chat.append(ChatMessage(role="user", content=prompt_text))
+
+        # Calculate number of images and audios for chat template
+        num_images = len(image_paths) if image_paths else 0
+        num_audios = len(audio_paths) if audio_paths else 0
+
+        # Apply chat template with image/audio token insertion
+        try:
+            formatted_prompt = vlm.apply_chat_template_with_media(chat, num_images=num_images, num_audios=num_audios)
+        except (NotImplementedError, AttributeError):
+            # Fallback to manual formatting if chat template is not implemented
+            formatted_prompt = ""
+            for msg in chat:
+                formatted_prompt += f"{msg.role}: {msg.content}\n"
+            formatted_prompt += "Assistant: "
+
+        # Generation config with media paths
+        generation_config = GenerationConfig(
+            max_tokens=args.max_tokens,
+            sampler_config=sampler_config,
+            image_paths=image_paths,
+            audio_paths=audio_paths
+        )
+
+        # Generate response
+        print("Assistant: ", end="", flush=True)
+
+        try:
+            # Use streaming generation with callback
+            response_text = ""
+
+            def token_callback(text_chunk):
+                nonlocal response_text
+                print(text_chunk, end="", flush=True)
+                response_text += text_chunk
+                return True
+
+            # Use generate_stream method for streaming generation
+            response = vlm.generate_stream(
+                prompt=formatted_prompt,
+                config=generation_config,
+                on_token=token_callback
+            )
+
+            print() # New line after streaming
+
+            # Add assistant response to chat history using ChatMessage
+            chat.append(ChatMessage(role="assistant", content=response_text))
+
+        except Exception as e:
+            print(f"Error generating response: {e}")
+            print()
+
+    # Clean up
+    vlm.destroy()
+
+def test_vlm_generate_stream(model_path, model_name):
     # Specify the checkpoint
     context_length = 2048
 
     # Load the corresponding model and VLM instance
     vlm = VLM(
+        model_name=model_name,
         model_path=model_path,
         mmproj_path=None, # Not needed for this model
         context_length=context_length,

@@ -88,9 +254,6 @@ def test_vlm_generate_stream(model_path):
         # Parse media files and prompt from user input
         prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)
 
-        print(f"image_paths: {image_paths}")
-        print(f"audio_paths: {audio_paths}")
-
         # If no text prompt after parsing, use the original input
         if not prompt_text.strip():
             prompt_text = user_input

@@ -150,8 +313,4 @@ def test_vlm_generate_stream(model_path):
     vlm.destroy()
 
 if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="mlx-community/gemma-3-4b-it-8bit")
-    args = parser.parse_args()
-    test_vlm_generate_stream(args.model_path)
+    main()
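The rewritten __main__ block shown in the last hunk now drives the interactive loop through main(). A rough way to exercise it, assuming the module can be invoked directly from an installed wheel (the invocation path is an assumption; the flags are the ones defined in parse_arguments above):

# Hedged sketch; the module invocation path is an assumption, flag values are illustrative.
import subprocess
import sys

subprocess.run([
    sys.executable, "-m", "nexaai.mlx_backend.vlm.main",
    "--model_path", "mlx-community/gemma-3-4b-it-8bit",  # default from parse_arguments
    "--model_name", "gemma3",
    "--max_tokens", "256",
])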
nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py
File without changes

nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py
ADDED
@@ -0,0 +1,117 @@
+import inspect
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import mlx.core as mx
+from mlx.utils import tree_map
+
+from .cache import QuantizedKVCache
+
+
+@dataclass
+class BaseModelArgs:
+    @classmethod
+    def from_dict(cls, params):
+        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})
+
+
+def create_causal_mask(
+    N: int,
+    offset: int = 0,
+    window_size: Optional[int] = None,
+    lengths: Optional[mx.array] = None,
+):
+    rinds = mx.arange(offset + N)
+    linds = mx.arange(offset, offset + N) if offset else rinds
+    linds = linds[:, None]
+    rinds = rinds[None]
+    mask = linds >= rinds
+    if window_size is not None:
+        mask = mask & (linds <= rinds + window_size)
+    if lengths is not None:
+        lengths = lengths[:, None, None, None]
+        mask = mask & (rinds < lengths)
+    return mask
+
+
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, return_array: bool = False):
+    T = h.shape[1]
+    if T > 1:
+        offset = 0
+        window_size = None
+        if cache is not None and cache[0] is not None:
+            c = cache[0]
+            offset = c.offset
+            if hasattr(c, "max_size"):
+                window_size = c.max_size
+                offset = min(window_size, offset)
+                return_array = return_array or offset + T > window_size
+        if return_array:
+            return create_causal_mask(T, offset, window_size=window_size)
+        else:
+            return "causal"
+    else:
+        mask = None
+    return mask
+
+
+def quantized_scaled_dot_product_attention(
+    queries: mx.array,
+    q_keys: tuple[mx.array, mx.array, mx.array],
+    q_values: tuple[mx.array, mx.array, mx.array],
+    scale: float,
+    mask: Optional[mx.array],
+    group_size: int = 64,
+    bits: int = 8,
+) -> mx.array:
+    B, n_q_heads, L, D = queries.shape
+    n_kv_heads = q_keys[0].shape[-3]
+    n_repeats = n_q_heads // n_kv_heads
+
+    queries *= scale
+
+    if n_repeats > 1:
+        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
+        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
+        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
+
+    scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
+    if mask is not None:
+        if isinstance(mask, str):
+            qL, kL = scores.shape[-2:]
+            q_indices = mx.arange(kL - qL, kL)
+            k_indices = mx.arange(kL)
+            mask = q_indices[:, None] >= k_indices[None]
+        if mask.dtype == mx.bool_:
+            scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
+        else:
+            scores += mask
+    scores = mx.softmax(scores, axis=-1, precise=True)
+    out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
+
+    if n_repeats > 1:
+        out = mx.reshape(out, (B, n_q_heads, L, D))
+
+    return out
+
+
+def scaled_dot_product_attention(
+    queries,
+    keys,
+    values,
+    cache,
+    scale: float,
+    mask: Optional[mx.array],
+) -> mx.array:
+    if isinstance(cache, QuantizedKVCache):
+        return quantized_scaled_dot_product_attention(
+            queries,
+            keys,
+            values,
+            scale=scale,
+            mask=mask,
+            group_size=cache.group_size,
+            bits=cache.bits,
+        )
+    else:
+        return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
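As a quick sanity check of the masking helpers in this new base.py, the sketch below re-derives the core of create_causal_mask for a tiny case; the values are illustrative and the snippet only assumes mlx is installed.

# Hedged sketch: reproduces the lower-triangular mask that create_causal_mask builds.
import mlx.core as mx

N, offset, window_size = 3, 0, None
rinds = mx.arange(offset + N)
linds = (mx.arange(offset, offset + N) if offset else rinds)[:, None]
mask = linds >= rinds[None]                               # row i may attend to columns j <= i
if window_size is not None:
    mask = mask & (linds <= rinds[None] + window_size)    # ...and to j >= i - window_size
print(mask)
# Lower-triangular pattern for N=3:
# [[True, False, False],
#  [True, True,  False],
#  [True, True,  True ]]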