nexaai 1.0.4rc13__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/__init__.py +71 -0
- nexaai/_stub.cp310-win_amd64.pyd +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +60 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +91 -0
- nexaai/asr_impl/pybind_asr_impl.py +43 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +3 -0
- nexaai/binds/common_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/embedder_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/llm_bind.cp310-win_amd64.pyd +0 -0
- nexaai/binds/nexa_bridge.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-base.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-cpu.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-cuda.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml-vulkan.dll +0 -0
- nexaai/binds/nexa_llama_cpp/ggml.dll +0 -0
- nexaai/binds/nexa_llama_cpp/llama.dll +0 -0
- nexaai/binds/nexa_llama_cpp/mtmd.dll +0 -0
- nexaai/binds/nexa_llama_cpp/nexa_plugin.dll +0 -0
- nexaai/common.py +61 -0
- nexaai/cv.py +87 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +88 -0
- nexaai/cv_impl/pybind_cv_impl.py +31 -0
- nexaai/embedder.py +68 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
- nexaai/image_gen.py +136 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
- nexaai/llm.py +89 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +249 -0
- nexaai/llm_impl/pybind_llm_impl.py +207 -0
- nexaai/rerank.py +51 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
- nexaai/runtime.py +64 -0
- nexaai/tts.py +70 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +93 -0
- nexaai/tts_impl/pybind_tts_impl.py +42 -0
- nexaai/utils/avatar_fetcher.py +104 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/model_manager.py +1195 -0
- nexaai/utils/progress_tracker.py +372 -0
- nexaai/vlm.py +120 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
- nexaai-1.0.4rc13.dist-info/METADATA +26 -0
- nexaai-1.0.4rc13.dist-info/RECORD +59 -0
- nexaai-1.0.4rc13.dist-info/WHEEL +5 -0
- nexaai-1.0.4rc13.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Progress tracking utilities for downloads with tqdm integration.
|
|
3
|
+
|
|
4
|
+
This module provides custom progress tracking classes that can monitor
|
|
5
|
+
download progress with callback support and customizable display options.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from typing import Optional, Callable, Dict, Any
|
|
12
|
+
from tqdm.auto import tqdm
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CustomProgressTqdm(tqdm):
    """Custom tqdm that tracks progress but completely hides terminal output.

    The bar stays enabled so its counters keep updating, but everything it
    would print is sent to ``os.devnull`` and the rendering hooks are no-ops.
    """

    def __init__(self, *args, **kwargs):
        """Create the bar, forcing output to devnull.

        Any caller-supplied ``file``, ``disable`` or ``leave`` value is
        overridden.
        """
        # Redirect output to devnull to completely suppress terminal output.
        sink = open(os.devnull, 'w')
        kwargs['file'] = sink
        kwargs['disable'] = False  # Keep enabled for tracking
        kwargs['leave'] = False    # Don't leave progress bar
        try:
            super().__init__(*args, **kwargs)
        except BaseException:
            # Do not leak the devnull handle when the parent constructor fails.
            sink.close()
            raise

    def display(self, msg=None, pos=None):
        """Render nothing."""
        pass

    @classmethod
    def write(cls, s, file=None, end="\n", nolock=False):
        """Discard the message.

        Declared as a classmethod, like ``tqdm.write``, so that both
        ``CustomProgressTqdm.write("msg")`` and ``bar.write("msg")`` work.
        """
        pass

    def close(self):
        """Close the devnull handle and mark the bar as finished without printing."""
        fp = getattr(self, 'fp', None)
        # ``!=`` (not ``is not``) is deliberate: the stored stream may be a
        # wrapper object that only compares equal to the stream it wraps.
        if fp and fp != sys.stdout and fp != sys.stderr:
            try:
                fp.close()
            except Exception:
                # Best effort: a failure to close devnull must not break the caller.
                pass
        self.disable = True
        # Skips the imported ``tqdm`` class's own ``close`` and dispatches to
        # the next class in the MRO.
        # NOTE(review): presumably relies on that close() returning early once
        # ``disable`` is set -- confirm against the installed tqdm version.
        super(tqdm, self).close()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DownloadProgressTracker:
    """Progress tracker for HuggingFace downloads with callback support.

    While tracking is active, ``__init__``, ``update``, ``display`` and
    ``write`` of the imported ``tqdm`` class are monkey patched: every bar
    created through that class is silenced, registered with this tracker, and
    folded into a single aggregate snapshot (see ``get_progress_data``) that
    is handed to ``progress_callback`` and optionally drawn as one unified
    console progress line.
    """

    # Names of the tqdm attributes replaced while tracking is active.
    _PATCHED_TQDM_ATTRS = ('__init__', 'update', 'display', 'write')

    def __init__(self, progress_callback: Optional[Callable[[Dict[str, Any]], None]] = None, show_progress: bool = True):
        """Create an idle tracker.

        Args:
            progress_callback: Called with the snapshot dict from
                ``get_progress_data`` on every registration, update and
                status change. Exceptions it raises are reported and swallowed.
            show_progress: Whether to draw the unified console progress line
                (only drawn when a callback is also set).
        """
        # Per-bar state, keyed by str(id(tqdm_instance)).
        self.progress_data: Dict[str, Dict[str, Any]] = {}
        self.total_repo_size = 0
        self.repo_file_count = 0
        self.original_tqdm_update = None
        self.original_tqdm_init = None
        self.original_tqdm_display = None
        self.original_tqdm_write = None
        # Raw class-dict entries of the patched attributes, used for an exact restore.
        self._saved_tqdm_attrs: Dict[str, Any] = {}
        self.is_tracking = False

        # Callback function
        self.progress_callback = progress_callback

        # Progress display
        self.show_progress = show_progress
        self.last_display_length = 0

        # Speed tracking
        self.last_downloaded = 0
        self.last_time = time.time()
        self.speed_history = []
        self.max_speed_history = 10

        # Download status
        self.download_status = "idle"  # idle, downloading, completed, error
        self.error_message = None
        self.download_start_time = None

    def set_repo_info(self, total_size: int, file_count: int):
        """Set the total repository size and file count before download."""
        self.total_repo_size = total_size
        self.repo_file_count = file_count

    def register_tqdm(self, tqdm_instance):
        """Register a tqdm instance for monitoring and fire the callback."""
        tqdm_id = str(id(tqdm_instance))
        self.progress_data[tqdm_id] = {
            'current': 0,
            'total': getattr(tqdm_instance, 'total', 0) or 0,
            'desc': getattr(tqdm_instance, 'desc', 'Unknown'),
            # Holding the object keeps it alive, so its id() cannot be reused.
            'tqdm_obj': tqdm_instance
        }
        # Trigger callback when new file is registered
        self._trigger_callback()

    def update_progress(self, tqdm_instance, n=1):
        """Refresh the stored counters for a registered tqdm instance.

        ``n`` is accepted for signature compatibility; the absolute values are
        read from the instance's ``n`` and ``total`` attributes.
        """
        tqdm_id = str(id(tqdm_instance))
        if tqdm_id in self.progress_data:
            entry = self.progress_data[tqdm_id]
            entry['current'] = getattr(tqdm_instance, 'n', 0)
            entry['total'] = getattr(tqdm_instance, 'total', 0) or 0
            # Trigger callback on every progress update
            self._trigger_callback()

    def calculate_speed(self, current_downloaded: int) -> float:
        """Calculate the smoothed download speed in bytes per second.

        A new sample is taken only when the clock has advanced since
        ``last_time`` and some bytes had already been seen at that point.
        When the clock has not advanced (coarse timer resolution, or several
        updates within one tick) the previous smoothed value is returned
        instead of a misleading 0.0.

        Args:
            current_downloaded: Total bytes downloaded so far.

        Returns:
            Average of the last ``max_speed_history`` samples, or 0.0 when no
            sample is available yet.
        """
        history = self.speed_history
        time_diff = time.time() - self.last_time

        if time_diff > 0 and self.last_downloaded > 0:
            speed = (current_downloaded - self.last_downloaded) / time_diff

            # Add to speed history for smoothing, dropping the oldest samples.
            history.append(speed)
            overflow = len(history) - self.max_speed_history
            if overflow > 0:
                del history[:overflow]

            # Return smoothed speed (fall back to the raw sample if the
            # history is configured to hold nothing).
            return sum(history) / len(history) if history else speed

        if time_diff <= 0 and history:
            # No measurable time elapsed: reuse the last smoothed value.
            return sum(history) / len(history)

        return 0.0

    def format_bytes(self, bytes_value: int) -> str:
        """Format bytes to human readable string (1024-based, one decimal)."""
        for unit in ['B', 'KB', 'MB', 'GB']:
            if bytes_value < 1024.0:
                return f"{bytes_value:.1f} {unit}"
            bytes_value /= 1024.0
        return f"{bytes_value:.1f} TB"

    def format_speed(self, speed: float) -> str:
        """Format speed to human readable string (1024-based, one decimal)."""
        if speed == 0:
            return "0 B/s"

        for unit in ['B/s', 'KB/s', 'MB/s', 'GB/s']:
            if speed < 1024.0:
                return f"{speed:.1f} {unit}"
            speed /= 1024.0
        return f"{speed:.1f} TB/s"

    def get_progress_data(self) -> Dict[str, Any]:
        """Build the aggregate progress snapshot.

        Only bars with a known positive ``total`` contribute. Calling this
        also takes a speed sample and advances the speed baseline.

        Returns:
            Dict with keys ``status``, ``error_message``, ``progress``,
            ``speed``, ``formatting`` and ``timing``.
        """
        total_downloaded = 0
        active_file_count = 0
        total_file_sizes = 0

        for data in self.progress_data.values():
            if data['total'] > 0:
                total_downloaded += data['current']
                total_file_sizes += data['total']
                active_file_count += 1

        # Calculate speed
        speed = self.calculate_speed(total_downloaded)

        # Update tracking variables. When the clock has not moved since the
        # last baseline, keep the old baseline so the bytes received in this
        # tick are still counted by the next sample.
        now = time.time()
        if now != self.last_time or self.last_downloaded <= 0:
            self.last_downloaded = total_downloaded
            self.last_time = now

        # Determine total size - prioritize pre-fetched repo size, then aggregate file sizes
        if self.total_repo_size > 0:
            # Use pre-fetched repository info if available
            total_size = self.total_repo_size
        elif total_file_sizes > 0:
            # Use sum of individual file sizes if available
            total_size = total_file_sizes
        else:
            # Last resort - we don't know the total size yet
            total_size = 0

        file_count = self.repo_file_count if self.repo_file_count > 0 else active_file_count

        # Calculate percentage - handle unknown total size gracefully
        if total_size > 0:
            percentage = min((total_downloaded / total_size * 100), 100.0)
        else:
            percentage = 0

        # Calculate ETA
        eta_seconds = None
        if speed > 0 and total_size > total_downloaded:
            eta_seconds = (total_size - total_downloaded) / speed

        # Calculate elapsed time
        elapsed_seconds = None
        if self.download_start_time:
            elapsed_seconds = time.time() - self.download_start_time

        return {
            'status': self.download_status,
            'error_message': self.error_message,
            'progress': {
                'total_downloaded': total_downloaded,
                'total_size': total_size,
                'percentage': round(percentage, 2),
                'files_active': active_file_count,
                'files_total': file_count,
                'known_total': total_size > 0
            },
            'speed': {
                'bytes_per_second': speed,
                'formatted': self.format_speed(speed)
            },
            'formatting': {
                'downloaded': self.format_bytes(total_downloaded),
                'total_size': self.format_bytes(total_size)
            },
            'timing': {
                'elapsed_seconds': elapsed_seconds,
                'eta_seconds': eta_seconds,
                'start_time': self.download_start_time
            }
        }

    def _display_progress_bar(self, progress_data: Dict[str, Any]):
        """Display a custom unified progress bar on the current console line.

        Args:
            progress_data: Snapshot in the shape returned by ``get_progress_data``.
        """
        if not self.show_progress:
            return

        # Clear previous line
        if self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')

        progress_info = progress_data.get('progress', {})
        speed_info = progress_data.get('speed', {})
        formatting_info = progress_data.get('formatting', {})

        percentage = progress_info.get('percentage', 0)
        downloaded = formatting_info.get('downloaded', '0 B')
        total_size_raw = progress_info.get('total_size', 0)
        total_size = formatting_info.get('total_size', 'Unknown')
        speed = speed_info.get('formatted', '0 B/s')
        known_total = progress_info.get('known_total', False)

        # Create progress bar
        bar_width = 30
        if known_total and total_size_raw > 0:
            # Known total size - show actual progress
            filled_width = int(bar_width * min(percentage, 100) / 100)
            bar = '█' * filled_width + '░' * (bar_width - filled_width)
        else:
            # Unknown total size - show animated progress (marker moves twice a second)
            animation_pos = int(time.time() * 2) % bar_width
            bar = '░' * animation_pos + '█' + '░' * (bar_width - animation_pos - 1)

        # Format the progress line
        status = progress_data.get('status', 'unknown')
        if status == 'downloading':
            if known_total:
                progress_line = f"[{bar}] {percentage:.1f}% | {downloaded}/{total_size} | {speed}"
            else:
                progress_line = f"[{bar}] {downloaded} | {speed} | Calculating size..."
        elif status == 'completed':
            progress_line = f"[{bar}] 100.0% | {downloaded} | Complete!"
        elif status == 'error':
            # The key is normally present with a None value, so ``or`` (not a
            # .get default) is needed to avoid printing "Error: None".
            error_text = progress_data.get('error_message') or 'Unknown error'
            progress_line = f"Error: {error_text}"
        else:
            progress_line = "Starting download..."

        # Display and track length for next clear
        print(progress_line, end='', flush=True)
        self.last_display_length = len(progress_line)

    def _clear_progress_bar(self):
        """Clear the progress bar display and move to a fresh line."""
        if self.show_progress and self.last_display_length > 0:
            print('\r' + ' ' * self.last_display_length, end='\r')
            print()  # Move to next line
            self.last_display_length = 0

    def _trigger_callback(self):
        """Take a snapshot, pass it to the callback, and redraw the bar.

        The snapshot is taken even without a callback because doing so
        advances the speed baseline.
        """
        progress_data = self.get_progress_data()

        if self.progress_callback:
            try:
                self.progress_callback(progress_data)
            except Exception as e:
                # A faulty callback must not abort the download.
                print(f"Error in progress callback: {e}")

        # Show custom progress bar only if callback is enabled and show_progress is True
        if self.progress_callback and self.show_progress:
            self._display_progress_bar(progress_data)

    def start_tracking(self):
        """Start progress tracking (monkey patch tqdm).

        No-op when tracking is already active. Sets the status to
        ``"downloading"``, records the start time and fires the callback.
        """
        if self.is_tracking:
            return

        # Remember the raw class-dict entries so stop_tracking can restore the
        # class exactly (including the classmethod descriptor for ``write``
        # and attributes that were only inherited).
        class_dict = vars(tqdm)
        self._saved_tqdm_attrs = {
            name: class_dict[name]
            for name in self._PATCHED_TQDM_ATTRS
            if name in class_dict
        }

        # Store original methods
        self.original_tqdm_update = tqdm.update
        self.original_tqdm_init = tqdm.__init__
        self.original_tqdm_display = tqdm.display
        self.original_tqdm_write = tqdm.write

        # Create references to self for the nested functions
        tracker = self

        def patched_init(self_tqdm, *args, **kwargs):
            # Suppress tqdm display by redirecting to devnull
            sink = open(os.devnull, 'w')
            kwargs['file'] = sink
            kwargs['disable'] = False  # Keep enabled for tracking
            kwargs['leave'] = False    # Don't leave progress bar

            try:
                result = tracker.original_tqdm_init(self_tqdm, *args, **kwargs)
            except BaseException:
                # Do not leak the devnull handle when construction fails.
                sink.close()
                raise
            tracker.register_tqdm(self_tqdm)
            return result

        def patched_update(self_tqdm, n=1):
            result = tracker.original_tqdm_update(self_tqdm, n)
            tracker.update_progress(self_tqdm, n)
            return result

        def patched_display(self_tqdm, msg=None, pos=None):
            # Override display to show nothing
            pass

        def patched_write(cls, s, file=None, end="\n", nolock=False):
            # Override write to prevent any output
            pass

        # Apply patches
        tqdm.__init__ = patched_init
        tqdm.update = patched_update
        tqdm.display = patched_display
        # tqdm.write is a classmethod; installing a plain function would make
        # ``tqdm.write("msg")`` fail with a missing-argument TypeError.
        tqdm.write = classmethod(patched_write)

        self.is_tracking = True
        self.download_status = "downloading"
        self.download_start_time = time.time()

        # Trigger initial callback
        self._trigger_callback()

    def stop_tracking(self):
        """Stop progress tracking and restore original tqdm.

        No-op when tracking is not active. A ``"downloading"`` status becomes
        ``"completed"``; an ``"error"`` status is left untouched.
        """
        if not self.is_tracking:
            return

        # Restore original tqdm attributes: put back what the class itself
        # defined, and delete our patch where the attribute was only inherited.
        saved = getattr(self, '_saved_tqdm_attrs', {})
        for name in self._PATCHED_TQDM_ATTRS:
            if name in saved:
                setattr(tqdm, name, saved[name])
            elif name in vars(tqdm):
                delattr(tqdm, name)
        self._saved_tqdm_attrs = {}

        # Clean up any open devnull file handles from tqdm instances
        for data in self.progress_data.values():
            if 'tqdm_obj' in data and hasattr(data['tqdm_obj'], 'fp'):
                try:
                    fp = data['tqdm_obj'].fp
                    # ``!=`` (not ``is not``) is deliberate: the stored stream
                    # may be a wrapper that compares equal to what it wraps.
                    if fp and fp != sys.stdout and fp != sys.stderr and not fp.closed:
                        fp.close()
                except Exception:
                    # Best effort: cleanup problems must not mask the download result.
                    pass

        self.is_tracking = False
        if self.download_status == "downloading":
            self.download_status = "completed"

        # Trigger final callback and clear progress bar
        self._trigger_callback()
        self._clear_progress_bar()

    def set_error(self, error_message: str):
        """Set error status and trigger callback."""
        self.download_status = "error"
        self.error_message = error_message
        self._trigger_callback()
        self._clear_progress_bar()
|
nexaai/vlm.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from typing import Generator, Optional, List, Dict, Any, Union
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
import queue
|
|
4
|
+
import threading
|
|
5
|
+
import base64
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from nexaai.common import ModelConfig, GenerationConfig, MultiModalMessage
|
|
9
|
+
from nexaai.base import BaseModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class VLM(BaseModel):
    """Abstract base class for vision-language models.

    Concrete implementations are selected by ``_load_from`` based on the
    plugin identifier and must implement the abstract methods below.
    """

    def __init__(self, m_cfg: Optional[ModelConfig] = None):
        """Initialize base VLM class.

        Args:
            m_cfg: Model configuration. A fresh ``ModelConfig()`` is created
                when omitted, so instances never share one default object.
        """
        self._m_cfg = ModelConfig() if m_cfg is None else m_cfg
        self._cancel_event = threading.Event()  # Controls cancellation of streaming

    @classmethod
    def _load_from(cls,
                   local_path: str,
                   mmproj_path: str,
                   m_cfg: Optional[ModelConfig] = None,
                   plugin_id: str = "llama_cpp",
                   device_id: Optional[str] = None
                   ) -> 'VLM':
        """Load VLM model from local path, routing to appropriate implementation.

        Args:
            local_path: Path to the main model file
            mmproj_path: Path to the multimodal projection file
            m_cfg: Model configuration (a fresh ``ModelConfig()`` when omitted)
            plugin_id: Plugin identifier
            device_id: Optional device ID (not used in current binding)

        Returns:
            VLM instance
        """
        # Resolve the default here so the implementations always receive a
        # real configuration object, as they did before.
        if m_cfg is None:
            m_cfg = ModelConfig()

        # Implementations are imported lazily so only the selected backend is loaded.
        if plugin_id == "mlx":
            from nexaai.vlm_impl.mlx_vlm_impl import MlxVlmImpl
            return MlxVlmImpl._load_from(local_path, mmproj_path, m_cfg, plugin_id, device_id)
        else:
            from nexaai.vlm_impl.pybind_vlm_impl import PyBindVLMImpl
            return PyBindVLMImpl._load_from(local_path, mmproj_path, m_cfg, plugin_id, device_id)

    @abstractmethod
    def eject(self):
        """Release the model from memory."""
        pass

    def cancel_generation(self):
        """Signal to cancel any ongoing stream generation."""
        self._cancel_event.set()

    def reset_cancel(self):
        """Reset the cancel event. Call before starting a new generation if needed."""
        self._cancel_event.clear()

    @abstractmethod
    def reset(self):
        """
        Reset the VLM model context and KV cache. If not reset, the model will skip the number of evaluated tokens and treat tokens after those as the new incremental tokens.
        If your past chat history changed, or you are starting a new chat, you should always reset the model before running generate.
        """
        pass

    def _process_image(self, image: Union[bytes, str, Path]) -> bytes:
        """Process image input to bytes format.

        Args:
            image: Image data as bytes, a ``data:image...`` URL carrying
                base64 data, or a file path (``str`` or ``Path``). A plain
                string that is not a data URL is treated as a file path.

        Returns:
            Image data as bytes

        Raises:
            ValueError: If the type is unsupported or a data URL has no
                ``,`` separating the header from the payload.
            OSError: If the file cannot be read.
        """
        if isinstance(image, bytes):
            return image

        if isinstance(image, str) and image.startswith('data:image'):
            # Extract base64 data from data URL. Decoding the whole URL when
            # the separator is missing could silently yield garbage, because
            # b64decode discards non-alphabet characters by default.
            _header, separator, base64_data = image.partition(',')
            if not separator:
                raise ValueError("Malformed image data URL: missing ',' before the base64 payload")
            return base64.b64decode(base64_data)

        if isinstance(image, (str, Path)):
            # Assume it's a file path
            with open(image, 'rb') as f:
                return f.read()

        raise ValueError(f"Unsupported image type: {type(image)}")

    @abstractmethod
    def apply_chat_template(
        self,
        messages: List[MultiModalMessage],
        tools: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """Apply the chat template to multimodal messages."""
        pass

    @abstractmethod
    def generate_stream(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> Generator[str, None, None]:
        """Generate text with streaming."""
        pass

    @abstractmethod
    def generate(self, prompt: str, g_cfg: GenerationConfig = GenerationConfig()) -> str:
        """
        Generate text without streaming.

        Args:
            prompt (str): The prompt to generate text from. For chat models, this is the chat messages after chat template is applied.
            g_cfg (GenerationConfig): Generation configuration.

        Returns:
            str: The generated text.
        """
        pass
|
|
File without changes
|