griptape-nodes 0.53.0__py3-none-any.whl → 0.54.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +5 -2
- griptape_nodes/app/app.py +4 -26
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
- griptape_nodes/cli/commands/config.py +4 -1
- griptape_nodes/cli/commands/init.py +5 -3
- griptape_nodes/cli/commands/libraries.py +14 -8
- griptape_nodes/cli/commands/models.py +504 -0
- griptape_nodes/cli/commands/self.py +5 -2
- griptape_nodes/cli/main.py +11 -1
- griptape_nodes/cli/shared.py +0 -9
- griptape_nodes/common/directed_graph.py +17 -1
- griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
- griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
- griptape_nodes/drivers/storage/local_storage_driver.py +17 -13
- griptape_nodes/exe_types/node_types.py +219 -14
- griptape_nodes/exe_types/param_components/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
- griptape_nodes/machines/control_flow.py +129 -92
- griptape_nodes/machines/dag_builder.py +207 -0
- griptape_nodes/machines/parallel_resolution.py +264 -276
- griptape_nodes/machines/sequential_resolution.py +9 -7
- griptape_nodes/node_library/library_registry.py +34 -1
- griptape_nodes/retained_mode/events/app_events.py +5 -1
- griptape_nodes/retained_mode/events/base_events.py +7 -7
- griptape_nodes/retained_mode/events/config_events.py +30 -0
- griptape_nodes/retained_mode/events/execution_events.py +2 -2
- griptape_nodes/retained_mode/events/model_events.py +296 -0
- griptape_nodes/retained_mode/griptape_nodes.py +10 -1
- griptape_nodes/retained_mode/managers/agent_manager.py +14 -0
- griptape_nodes/retained_mode/managers/config_manager.py +44 -3
- griptape_nodes/retained_mode/managers/event_manager.py +8 -2
- griptape_nodes/retained_mode/managers/flow_manager.py +45 -14
- griptape_nodes/retained_mode/managers/library_manager.py +3 -3
- griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
- griptape_nodes/retained_mode/managers/node_manager.py +26 -26
- griptape_nodes/retained_mode/managers/object_manager.py +1 -1
- griptape_nodes/retained_mode/managers/os_manager.py +6 -6
- griptape_nodes/retained_mode/managers/settings.py +87 -9
- griptape_nodes/retained_mode/managers/static_files_manager.py +77 -9
- griptape_nodes/retained_mode/managers/sync_manager.py +10 -5
- griptape_nodes/retained_mode/managers/workflow_manager.py +98 -92
- griptape_nodes/retained_mode/retained_mode.py +19 -0
- griptape_nodes/servers/__init__.py +1 -0
- griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
- griptape_nodes/{app/api.py → servers/static.py} +43 -40
- griptape_nodes/traits/button.py +124 -6
- griptape_nodes/traits/multi_options.py +188 -0
- griptape_nodes/traits/numbers_selector.py +77 -0
- griptape_nodes/traits/options.py +93 -2
- griptape_nodes/utils/async_utils.py +31 -0
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +3 -1
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +56 -47
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
- /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,1107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import re
|
|
7
|
+
import sys
|
|
8
|
+
import threading
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from datetime import UTC, datetime
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import TYPE_CHECKING, Any
|
|
13
|
+
from urllib.parse import urlparse
|
|
14
|
+
|
|
15
|
+
from huggingface_hub import list_models, scan_cache_dir, snapshot_download
|
|
16
|
+
from huggingface_hub.utils.tqdm import tqdm
|
|
17
|
+
from xdg_base_dirs import xdg_data_home
|
|
18
|
+
|
|
19
|
+
from griptape_nodes.retained_mode.events.app_events import AppInitializationComplete
|
|
20
|
+
from griptape_nodes.retained_mode.events.model_events import (
|
|
21
|
+
DeleteModelDownloadRequest,
|
|
22
|
+
DeleteModelDownloadResultFailure,
|
|
23
|
+
DeleteModelDownloadResultSuccess,
|
|
24
|
+
DeleteModelRequest,
|
|
25
|
+
DeleteModelResultFailure,
|
|
26
|
+
DeleteModelResultSuccess,
|
|
27
|
+
DownloadModelRequest,
|
|
28
|
+
DownloadModelResultFailure,
|
|
29
|
+
DownloadModelResultSuccess,
|
|
30
|
+
ListModelDownloadsRequest,
|
|
31
|
+
ListModelDownloadsResultFailure,
|
|
32
|
+
ListModelDownloadsResultSuccess,
|
|
33
|
+
ListModelsRequest,
|
|
34
|
+
ListModelsResultFailure,
|
|
35
|
+
ListModelsResultSuccess,
|
|
36
|
+
ModelDownloadStatus,
|
|
37
|
+
ModelInfo,
|
|
38
|
+
QueryInfo,
|
|
39
|
+
SearchModelsRequest,
|
|
40
|
+
SearchModelsResultFailure,
|
|
41
|
+
SearchModelsResultSuccess,
|
|
42
|
+
)
|
|
43
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
44
|
+
from griptape_nodes.utils.async_utils import cancel_subprocess
|
|
45
|
+
|
|
46
|
+
if TYPE_CHECKING:
|
|
47
|
+
from griptape_nodes.retained_mode.events.base_events import ResultPayload
|
|
48
|
+
from griptape_nodes.retained_mode.managers.event_manager import EventManager
|
|
49
|
+
|
|
50
|
+
logger = logging.getLogger("griptape_nodes")


# HTTP status codes for auth failures.
# NOTE(review): not referenced in this chunk — presumably used by Hugging Face
# Hub auth error handling elsewhere in this module; confirm before removing.
HTTP_UNAUTHORIZED = 401
HTTP_FORBIDDEN = 403

# Minimum number of path components expected when parsing a cache directory.
# NOTE(review): not referenced in this chunk — verify against the cache-scan code.
MIN_CACHE_DIR_PARTS = 3
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class SearchResultsData:
    """Data class for model search results."""

    # Models returned by the Hugging Face Hub search.
    models: list[ModelInfo]
    # Count of returned models (set to len(models) by _search_models).
    total_results: int
    # Echo of the search parameters that produced these results.
    query_info: QueryInfo
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class ModelDownloadTracker(tqdm):
    """Custom tqdm progress bar that tracks aggregate model download progress.

    Passed to huggingface_hub as ``tqdm_class`` (see ``download_model``); each
    progress tick is mirrored into a per-model JSON status file so other
    processes can observe download state.
    """

    # Serializes all status-file reads/writes across threads.
    _file_lock = threading.Lock()
    # Model currently being downloaded; consulted when huggingface_hub
    # instantiates the tracker without an explicit model_id.
    # NOTE(review): class-level state assumes one download at a time per
    # process — concurrent downloads would share this value. TODO confirm.
    _current_model_id = ""

    @classmethod
    def set_current_model_id(cls, model_id: str) -> None:
        """Set the current model being downloaded."""
        cls._current_model_id = model_id

    def __init__(self, *args, model_id: str = "", **kwargs) -> None:
        # Fall back to the class-level model id when the caller does not
        # supply one (huggingface_hub constructs the bar internally).
        if not model_id and self._current_model_id:
            model_id = self._current_model_id

        super().__init__(*args, **kwargs)
        self.model_id = model_id
        # ISO-8601 UTC timestamp of when this progress bar was created.
        self.start_time = datetime.now(UTC).isoformat()

        logger.debug(
            "ModelDownloadTracker created - model_id: %s, total: %s, desc: %s, args: %s",
            self.model_id,
            self.total,
            getattr(self, "desc", None),
            args,
        )

        # Only persist progress when we know which model this bar tracks.
        if self.model_id:
            self._init_status_file()

    def update(self, n: int = 1) -> None:
        """Override update to track progress in status file."""
        logger.debug(
            "ModelDownloadTracker update - model_id: %s, n: %s, self.n: %s, total: %s",
            self.model_id,
            n,
            self.n,
            self.total,
        )
        super().update(n)
        # Persist after every tick so external readers stay current.
        self._update_status_file()

    def close(self) -> None:
        """Override close to log download completion."""
        super().close()
        logger.debug(
            "ModelDownloadTracker close - model_id: %s, self.n: %s, total: %s", self.model_id, self.n, self.total
        )

    def _get_status_file_path(self) -> Path:
        """Get the path to the status file for this model."""
        status_dir = xdg_data_home() / "griptape_nodes" / "model_downloads"
        status_dir.mkdir(parents=True, exist_ok=True)

        # Make the model id filesystem-safe, e.g. "org/model" -> "org--model".
        sanitized_model_id = re.sub(r"[^\w\-_]", "--", self.model_id)
        return status_dir / f"{sanitized_model_id}.json"

    def _init_status_file(self) -> None:
        """Initialize the status file for this model."""
        try:
            with self._file_lock:
                status_file = self._get_status_file_path()
                current_time = datetime.now(UTC).isoformat()

                logger.info(
                    "ModelDownloadTracker initializing status file: %s (total_files=%s)", status_file, self.total
                )

                data = {
                    "model_id": self.model_id,
                    "status": "downloading",
                    "started_at": current_time,
                    "updated_at": current_time,
                    # tqdm's total may still be None at construction time.
                    "total_files": self.total or 0,
                    "downloaded_files": 0,
                    "progress_percent": 0.0,
                }

                with status_file.open("w") as f:
                    json.dump(data, f, indent=2)

                logger.info("ModelDownloadTracker status file initialized successfully")

        except Exception:
            # Progress tracking is best-effort; never let it break the download.
            logger.exception("ModelDownloadTracker._init_status_file failed")

    def _update_status_file(self) -> None:
        """Update the status file with current progress."""
        if not self.model_id:
            logger.warning("ModelDownloadTracker._update_status_file called with empty model_id")
            return

        try:
            with self._file_lock:
                status_file = self._get_status_file_path()
                logger.info("ModelDownloadTracker updating status file: %s", status_file)

                # _init_status_file must have created the file first.
                if not status_file.exists():
                    logger.warning("Status file does not exist: %s", status_file)
                    return

                with status_file.open() as f:
                    data = json.load(f)

                current_time = datetime.now(UTC).isoformat()
                progress_percent = (self.n / self.total * 100) if self.total else 0

                logger.info(
                    "ModelDownloadTracker updating progress: files=%d/%d, percent=%.1f%%",
                    self.n,
                    self.total,
                    progress_percent,
                )

                data.update(
                    {
                        "downloaded_files": self.n,
                        "progress_percent": progress_percent,
                        "updated_at": current_time,
                    }
                )

                with status_file.open("w") as f:
                    json.dump(data, f, indent=2)

                logger.debug("ModelDownloadTracker status file updated successfully")

        except Exception:
            # Best-effort persistence; swallow errors so the download continues.
            logger.exception("ModelDownloadTracker._update_status_file failed")
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class ModelManager:
|
|
200
|
+
"""A manager for downloading models from Hugging Face Hub.
|
|
201
|
+
|
|
202
|
+
This manager provides async handlers for downloading models using the Hugging Face Hub API.
|
|
203
|
+
It supports downloading entire model repositories or specific files, with caching and
|
|
204
|
+
local storage management.
|
|
205
|
+
"""
|
|
206
|
+
|
|
207
|
+
def __init__(self, event_manager: EventManager | None = None) -> None:
|
|
208
|
+
"""Initialize the ModelManager.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
event_manager: The EventManager instance to use for event handling.
|
|
212
|
+
"""
|
|
213
|
+
self._download_tasks = {}
|
|
214
|
+
self._download_processes = {}
|
|
215
|
+
|
|
216
|
+
if event_manager is not None:
|
|
217
|
+
event_manager.assign_manager_to_request_type(DownloadModelRequest, self.on_handle_download_model_request)
|
|
218
|
+
event_manager.assign_manager_to_request_type(ListModelsRequest, self.on_handle_list_models_request)
|
|
219
|
+
event_manager.assign_manager_to_request_type(DeleteModelRequest, self.on_handle_delete_model_request)
|
|
220
|
+
event_manager.assign_manager_to_request_type(SearchModelsRequest, self.on_handle_search_models_request)
|
|
221
|
+
event_manager.assign_manager_to_request_type(
|
|
222
|
+
ListModelDownloadsRequest, self.on_handle_list_model_downloads_request
|
|
223
|
+
)
|
|
224
|
+
event_manager.assign_manager_to_request_type(
|
|
225
|
+
DeleteModelDownloadRequest, self.on_handle_delete_model_download_request
|
|
226
|
+
)
|
|
227
|
+
|
|
228
|
+
event_manager.add_listener_to_app_event(AppInitializationComplete, self.on_app_initialization_complete)
|
|
229
|
+
|
|
230
|
+
def download_model(
|
|
231
|
+
self,
|
|
232
|
+
model_id: str,
|
|
233
|
+
local_dir: str | None = None,
|
|
234
|
+
revision: str = "main",
|
|
235
|
+
allow_patterns: list[str] | None = None,
|
|
236
|
+
ignore_patterns: list[str] | None = None,
|
|
237
|
+
) -> str:
|
|
238
|
+
"""Direct model download method that can be used without event system.
|
|
239
|
+
|
|
240
|
+
This method contains the core download logic without going through
|
|
241
|
+
the event system, avoiding recursion issues.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
model_id: Model ID to download
|
|
245
|
+
local_dir: Optional local directory to download to
|
|
246
|
+
revision: Git revision to download
|
|
247
|
+
allow_patterns: Optional glob patterns to include
|
|
248
|
+
ignore_patterns: Optional glob patterns to exclude
|
|
249
|
+
|
|
250
|
+
Returns:
|
|
251
|
+
str: Local path where the model was downloaded
|
|
252
|
+
|
|
253
|
+
Raises:
|
|
254
|
+
Exception: If download fails
|
|
255
|
+
"""
|
|
256
|
+
# Set up progress tracking
|
|
257
|
+
ModelDownloadTracker.set_current_model_id(model_id)
|
|
258
|
+
|
|
259
|
+
try:
|
|
260
|
+
# Build download kwargs
|
|
261
|
+
download_kwargs = {
|
|
262
|
+
"repo_id": model_id,
|
|
263
|
+
"repo_type": "model",
|
|
264
|
+
"revision": revision,
|
|
265
|
+
"tqdm_class": ModelDownloadTracker,
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
# Add optional parameters
|
|
269
|
+
if local_dir:
|
|
270
|
+
download_kwargs["local_dir"] = local_dir
|
|
271
|
+
if allow_patterns:
|
|
272
|
+
download_kwargs["allow_patterns"] = allow_patterns
|
|
273
|
+
if ignore_patterns:
|
|
274
|
+
download_kwargs["ignore_patterns"] = ignore_patterns
|
|
275
|
+
|
|
276
|
+
# Execute download with progress tracking
|
|
277
|
+
local_path = snapshot_download(**download_kwargs) # type: ignore[arg-type]
|
|
278
|
+
|
|
279
|
+
return str(local_path)
|
|
280
|
+
|
|
281
|
+
finally:
|
|
282
|
+
# Clear the current model ID when done
|
|
283
|
+
ModelDownloadTracker.set_current_model_id("")
|
|
284
|
+
|
|
285
|
+
def _get_status_directory(self) -> Path:
|
|
286
|
+
"""Get the status directory path for model downloads.
|
|
287
|
+
|
|
288
|
+
Returns:
|
|
289
|
+
Path: Path to the status directory, creating it if needed
|
|
290
|
+
"""
|
|
291
|
+
status_dir = xdg_data_home() / "griptape_nodes" / "model_downloads"
|
|
292
|
+
status_dir.mkdir(parents=True, exist_ok=True)
|
|
293
|
+
return status_dir
|
|
294
|
+
|
|
295
|
+
async def on_handle_download_model_request(self, request: DownloadModelRequest) -> ResultPayload:
|
|
296
|
+
"""Handle model download requests asynchronously.
|
|
297
|
+
|
|
298
|
+
This method starts a background task to download models from Hugging Face Hub using the provided parameters.
|
|
299
|
+
It supports both model IDs and full URLs, and can download entire repositories
|
|
300
|
+
or specific files based on the patterns provided.
|
|
301
|
+
|
|
302
|
+
Args:
|
|
303
|
+
request: The download request containing model ID and options
|
|
304
|
+
|
|
305
|
+
Returns:
|
|
306
|
+
ResultPayload: Success result indicating download started or failure with error details
|
|
307
|
+
"""
|
|
308
|
+
parsed_model_id = self._parse_model_id(request.model_id)
|
|
309
|
+
if parsed_model_id != request.model_id:
|
|
310
|
+
logger.debug("Parsed model ID '%s' from URL '%s'", parsed_model_id, request.model_id)
|
|
311
|
+
|
|
312
|
+
try:
|
|
313
|
+
download_params = {
|
|
314
|
+
"model_id": parsed_model_id,
|
|
315
|
+
"local_dir": request.local_dir,
|
|
316
|
+
"revision": request.revision,
|
|
317
|
+
"allow_patterns": request.allow_patterns,
|
|
318
|
+
"ignore_patterns": request.ignore_patterns,
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
task = asyncio.create_task(self._download_model_task(download_params))
|
|
322
|
+
self._download_tasks[parsed_model_id] = task
|
|
323
|
+
|
|
324
|
+
result_details = f"Started background download for model '{parsed_model_id}'"
|
|
325
|
+
|
|
326
|
+
return DownloadModelResultSuccess(
|
|
327
|
+
model_id=parsed_model_id,
|
|
328
|
+
result_details=result_details,
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
except Exception as e:
|
|
332
|
+
error_msg = f"Failed to start download for model '{request.model_id}': {e}"
|
|
333
|
+
return DownloadModelResultFailure(
|
|
334
|
+
result_details=error_msg,
|
|
335
|
+
exception=e,
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
async def _download_model_task(self, download_params: dict[str, str | list[str] | None]) -> None:
|
|
339
|
+
"""Background task for downloading a model using CLI command.
|
|
340
|
+
|
|
341
|
+
Args:
|
|
342
|
+
download_params: Dictionary containing download parameters
|
|
343
|
+
"""
|
|
344
|
+
model_id = download_params["model_id"]
|
|
345
|
+
logger.info("Starting background download for model: %s", model_id)
|
|
346
|
+
|
|
347
|
+
# Build CLI command arguments
|
|
348
|
+
cmd = [sys.executable, "-m", "griptape_nodes", "models", "download", str(model_id)]
|
|
349
|
+
|
|
350
|
+
# Add optional parameters
|
|
351
|
+
if download_params.get("local_dir"):
|
|
352
|
+
cmd.extend(["--local-dir", str(download_params["local_dir"])])
|
|
353
|
+
if download_params.get("revision") and download_params["revision"] != "main":
|
|
354
|
+
cmd.extend(["--revision", str(download_params["revision"])])
|
|
355
|
+
|
|
356
|
+
try:
|
|
357
|
+
# Start subprocess
|
|
358
|
+
process = await asyncio.create_subprocess_exec(
|
|
359
|
+
*cmd,
|
|
360
|
+
stdout=asyncio.subprocess.PIPE,
|
|
361
|
+
stderr=asyncio.subprocess.PIPE,
|
|
362
|
+
)
|
|
363
|
+
|
|
364
|
+
# Store process for cancellation
|
|
365
|
+
if isinstance(model_id, str):
|
|
366
|
+
self._download_processes[model_id] = process
|
|
367
|
+
|
|
368
|
+
# Wait for completion (this can be cancelled)
|
|
369
|
+
stdout, stderr = await process.communicate()
|
|
370
|
+
|
|
371
|
+
if process.returncode == 0:
|
|
372
|
+
logger.info("Successfully downloaded model '%s'", model_id)
|
|
373
|
+
else:
|
|
374
|
+
logger.error("Download failed for model '%s': %s", model_id, stderr.decode())
|
|
375
|
+
|
|
376
|
+
except asyncio.CancelledError:
|
|
377
|
+
logger.info("Download cancelled for model '%s'", model_id)
|
|
378
|
+
raise
|
|
379
|
+
|
|
380
|
+
except Exception:
|
|
381
|
+
logger.exception("Error downloading model '%s'", model_id)
|
|
382
|
+
|
|
383
|
+
finally:
|
|
384
|
+
if isinstance(model_id, str):
|
|
385
|
+
if model_id in self._download_tasks:
|
|
386
|
+
del self._download_tasks[model_id]
|
|
387
|
+
if model_id in self._download_processes:
|
|
388
|
+
del self._download_processes[model_id]
|
|
389
|
+
|
|
390
|
+
async def on_handle_list_models_request(self, request: ListModelsRequest) -> ResultPayload: # noqa: ARG002
|
|
391
|
+
"""Handle model listing requests asynchronously.
|
|
392
|
+
|
|
393
|
+
This method scans the local Hugging Face cache directory to find downloaded models
|
|
394
|
+
and returns information about each model including path, size, and metadata.
|
|
395
|
+
|
|
396
|
+
Args:
|
|
397
|
+
request: The list request (no parameters needed)
|
|
398
|
+
|
|
399
|
+
Returns:
|
|
400
|
+
ResultPayload: Success result with model list or failure with error details
|
|
401
|
+
"""
|
|
402
|
+
try:
|
|
403
|
+
# Get models in a thread to avoid blocking the event loop
|
|
404
|
+
models = await asyncio.to_thread(self._list_models)
|
|
405
|
+
|
|
406
|
+
result_details = f"Found {len(models)} cached models"
|
|
407
|
+
|
|
408
|
+
return ListModelsResultSuccess(
|
|
409
|
+
models=models,
|
|
410
|
+
result_details=result_details,
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
except Exception as e:
|
|
414
|
+
error_msg = f"Failed to list models: {e}"
|
|
415
|
+
return ListModelsResultFailure(
|
|
416
|
+
result_details=error_msg,
|
|
417
|
+
exception=e,
|
|
418
|
+
)
|
|
419
|
+
|
|
420
|
+
async def on_handle_delete_model_request(self, request: DeleteModelRequest) -> ResultPayload:
|
|
421
|
+
"""Handle model deletion requests asynchronously.
|
|
422
|
+
|
|
423
|
+
This method removes a model from the local Hugging Face cache directory and
|
|
424
|
+
cleans up any associated download tracking records.
|
|
425
|
+
|
|
426
|
+
Args:
|
|
427
|
+
request: The delete request containing model_id
|
|
428
|
+
|
|
429
|
+
Returns:
|
|
430
|
+
ResultPayload: Success result with deletion confirmation or failure with error details
|
|
431
|
+
"""
|
|
432
|
+
# Parse the model ID from potential URL
|
|
433
|
+
model_id = request.model_id
|
|
434
|
+
|
|
435
|
+
deleted_items = []
|
|
436
|
+
|
|
437
|
+
try:
|
|
438
|
+
deleted_path = await asyncio.to_thread(self._delete_model, model_id)
|
|
439
|
+
deleted_items.append(f"model files from '{deleted_path}'")
|
|
440
|
+
|
|
441
|
+
except FileNotFoundError:
|
|
442
|
+
logger.debug("No model files found for '%s' in cache", model_id)
|
|
443
|
+
|
|
444
|
+
except Exception as e:
|
|
445
|
+
error_msg = f"Failed to delete model files for '{model_id}': {e}"
|
|
446
|
+
return DeleteModelResultFailure(
|
|
447
|
+
result_details=error_msg,
|
|
448
|
+
exception=e,
|
|
449
|
+
)
|
|
450
|
+
|
|
451
|
+
if not deleted_items:
|
|
452
|
+
error_msg = f"Model '{model_id}' not found (no cached files or download records)"
|
|
453
|
+
return DeleteModelResultFailure(
|
|
454
|
+
result_details=error_msg,
|
|
455
|
+
exception=FileNotFoundError(error_msg),
|
|
456
|
+
)
|
|
457
|
+
|
|
458
|
+
deleted_description = " and ".join(deleted_items)
|
|
459
|
+
result_details = f"Successfully deleted {deleted_description} for model '{model_id}'"
|
|
460
|
+
|
|
461
|
+
return DeleteModelResultSuccess(
|
|
462
|
+
model_id=model_id,
|
|
463
|
+
deleted_path=deleted_items[0] if deleted_items else "",
|
|
464
|
+
result_details=result_details,
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
async def on_handle_search_models_request(self, request: SearchModelsRequest) -> ResultPayload:
|
|
468
|
+
"""Handle model search requests asynchronously.
|
|
469
|
+
|
|
470
|
+
This method searches for models on Hugging Face Hub using the provided parameters.
|
|
471
|
+
It supports filtering by query, task, library, author, and tags.
|
|
472
|
+
|
|
473
|
+
Args:
|
|
474
|
+
request: The search request containing search parameters
|
|
475
|
+
|
|
476
|
+
Returns:
|
|
477
|
+
ResultPayload: Success result with model list or failure with error details
|
|
478
|
+
"""
|
|
479
|
+
try:
|
|
480
|
+
# Search models in a thread to avoid blocking the event loop
|
|
481
|
+
search_results = await asyncio.to_thread(self._search_models, request)
|
|
482
|
+
except Exception as e:
|
|
483
|
+
error_msg = f"Failed to search models: {e}"
|
|
484
|
+
return SearchModelsResultFailure(
|
|
485
|
+
result_details=error_msg,
|
|
486
|
+
exception=e,
|
|
487
|
+
)
|
|
488
|
+
else:
|
|
489
|
+
result_details = f"Found {len(search_results.models)} models"
|
|
490
|
+
return SearchModelsResultSuccess(
|
|
491
|
+
models=search_results.models,
|
|
492
|
+
total_results=search_results.total_results,
|
|
493
|
+
query_info=search_results.query_info,
|
|
494
|
+
result_details=result_details,
|
|
495
|
+
)
|
|
496
|
+
|
|
497
|
+
def _search_models(self, request: SearchModelsRequest) -> SearchResultsData:
|
|
498
|
+
"""Synchronous model search implementation.
|
|
499
|
+
|
|
500
|
+
Searches for models on Hugging Face Hub using the huggingface_hub API.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
request: The search request parameters
|
|
504
|
+
|
|
505
|
+
Returns:
|
|
506
|
+
SearchResultsData: Dataclass containing models list, total results, and query info
|
|
507
|
+
"""
|
|
508
|
+
# Build search parameters
|
|
509
|
+
search_params = {}
|
|
510
|
+
|
|
511
|
+
if request.query:
|
|
512
|
+
search_params["search"] = request.query
|
|
513
|
+
if request.task:
|
|
514
|
+
search_params["task"] = request.task
|
|
515
|
+
if request.library:
|
|
516
|
+
search_params["library"] = request.library
|
|
517
|
+
if request.author:
|
|
518
|
+
search_params["author"] = request.author
|
|
519
|
+
if request.tags:
|
|
520
|
+
search_params["tags"] = request.tags
|
|
521
|
+
|
|
522
|
+
# Validate and set sort parameters
|
|
523
|
+
valid_sorts = ["downloads", "likes", "updated", "created"]
|
|
524
|
+
sort_param = request.sort if request.sort in valid_sorts else "downloads"
|
|
525
|
+
search_params["sort"] = sort_param
|
|
526
|
+
|
|
527
|
+
# Only add direction for sorts that support it (downloads only supports descending)
|
|
528
|
+
if sort_param != "downloads":
|
|
529
|
+
# Convert direction to the format expected by HF Hub API (-1 for asc, 1 for desc)
|
|
530
|
+
direction_param = -1 if request.direction == "asc" else 1
|
|
531
|
+
search_params["direction"] = direction_param
|
|
532
|
+
|
|
533
|
+
# Limit results (max 100 as per HF Hub API)
|
|
534
|
+
limit = min(max(1, request.limit), 100)
|
|
535
|
+
|
|
536
|
+
# Perform the search
|
|
537
|
+
models_iterator = list_models(limit=limit, **search_params)
|
|
538
|
+
|
|
539
|
+
# Convert models to list and extract information
|
|
540
|
+
models_list = []
|
|
541
|
+
for model in models_iterator:
|
|
542
|
+
created_at = getattr(model, "created_at", None)
|
|
543
|
+
updated_at = getattr(model, "last_modified", None)
|
|
544
|
+
|
|
545
|
+
model_info = ModelInfo(
|
|
546
|
+
model_id=model.id,
|
|
547
|
+
author=getattr(model, "author", None),
|
|
548
|
+
downloads=getattr(model, "downloads", None),
|
|
549
|
+
likes=getattr(model, "likes", None),
|
|
550
|
+
created_at=created_at.isoformat() if created_at else None,
|
|
551
|
+
updated_at=updated_at.isoformat() if updated_at else None,
|
|
552
|
+
task=getattr(model, "pipeline_tag", None),
|
|
553
|
+
library=getattr(model, "library_name", None),
|
|
554
|
+
tags=getattr(model, "tags", None),
|
|
555
|
+
)
|
|
556
|
+
models_list.append(model_info)
|
|
557
|
+
|
|
558
|
+
# Prepare query info for response
|
|
559
|
+
query_info = QueryInfo(
|
|
560
|
+
query=request.query,
|
|
561
|
+
task=request.task,
|
|
562
|
+
library=request.library,
|
|
563
|
+
author=request.author,
|
|
564
|
+
tags=request.tags,
|
|
565
|
+
limit=limit,
|
|
566
|
+
sort=sort_param,
|
|
567
|
+
direction=request.direction, # Keep the original user-friendly format
|
|
568
|
+
)
|
|
569
|
+
|
|
570
|
+
return SearchResultsData(
|
|
571
|
+
models=models_list,
|
|
572
|
+
total_results=len(models_list),
|
|
573
|
+
query_info=query_info,
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
async def on_app_initialization_complete(self, _payload: AppInitializationComplete) -> None:
|
|
577
|
+
"""Handle app initialization complete event by downloading configured models and resuming unfinished downloads.
|
|
578
|
+
|
|
579
|
+
Args:
|
|
580
|
+
payload: The app initialization complete payload
|
|
581
|
+
"""
|
|
582
|
+
# Get models to download from configuration
|
|
583
|
+
config_manager = GriptapeNodes.ConfigManager()
|
|
584
|
+
models_to_download = config_manager.get_config_value(
|
|
585
|
+
"app_events.on_app_initialization_complete.models_to_download", default=[]
|
|
586
|
+
)
|
|
587
|
+
|
|
588
|
+
# Find unfinished downloads to resume
|
|
589
|
+
unfinished_models = await asyncio.to_thread(self._find_unfinished_downloads)
|
|
590
|
+
|
|
591
|
+
# Combine new downloads and unfinished ones, avoiding duplicates
|
|
592
|
+
all_models = list(
|
|
593
|
+
dict.fromkeys(
|
|
594
|
+
[
|
|
595
|
+
*[model_id for model_id in models_to_download if model_id], # Filter empty strings
|
|
596
|
+
*unfinished_models,
|
|
597
|
+
]
|
|
598
|
+
)
|
|
599
|
+
)
|
|
600
|
+
|
|
601
|
+
if not all_models:
|
|
602
|
+
logger.debug("No models to download or resume")
|
|
603
|
+
return
|
|
604
|
+
|
|
605
|
+
logger.info(
|
|
606
|
+
"Starting download/resume of %d models (%d new, %d resuming)",
|
|
607
|
+
len(all_models),
|
|
608
|
+
len(models_to_download),
|
|
609
|
+
len(unfinished_models),
|
|
610
|
+
)
|
|
611
|
+
|
|
612
|
+
# Create download tasks for concurrent execution
|
|
613
|
+
download_tasks = []
|
|
614
|
+
for model_id in all_models:
|
|
615
|
+
task = asyncio.create_task(self._download_model_with_logging(model_id))
|
|
616
|
+
download_tasks.append(task)
|
|
617
|
+
|
|
618
|
+
# Wait for all downloads to complete
|
|
619
|
+
results = await asyncio.gather(*download_tasks, return_exceptions=True)
|
|
620
|
+
|
|
621
|
+
# Log summary of results
|
|
622
|
+
successful = sum(1 for result in results if not isinstance(result, Exception))
|
|
623
|
+
failed = len(results) - successful
|
|
624
|
+
|
|
625
|
+
logger.info("Completed automatic model downloads: %d successful, %d failed", successful, failed)
|
|
626
|
+
|
|
627
|
+
async def _download_model_with_logging(self, model_id: str) -> None:
|
|
628
|
+
"""Download a single model with proper logging.
|
|
629
|
+
|
|
630
|
+
Args:
|
|
631
|
+
model_id: The model ID to download
|
|
632
|
+
"""
|
|
633
|
+
logger.info("Auto-downloading model: %s", model_id)
|
|
634
|
+
|
|
635
|
+
# Create download request with default parameters
|
|
636
|
+
request = DownloadModelRequest(
|
|
637
|
+
model_id=model_id,
|
|
638
|
+
local_dir=None, # Use default cache directory
|
|
639
|
+
revision="main",
|
|
640
|
+
allow_patterns=None,
|
|
641
|
+
ignore_patterns=None,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
try:
|
|
645
|
+
# Run the download asynchronously
|
|
646
|
+
result = await self.on_handle_download_model_request(request)
|
|
647
|
+
|
|
648
|
+
if isinstance(result, DownloadModelResultFailure):
|
|
649
|
+
logger.warning("Failed to auto-download model '%s': %s", model_id, result.result_details)
|
|
650
|
+
elif not isinstance(result, DownloadModelResultSuccess):
|
|
651
|
+
logger.warning("Unknown result type for model '%s' download: %s", model_id, type(result))
|
|
652
|
+
|
|
653
|
+
except Exception as e:
|
|
654
|
+
logger.error("Unexpected error auto-downloading model '%s': %s", model_id, e)
|
|
655
|
+
raise
|
|
656
|
+
|
|
657
|
+
def _get_status_file_path(self, model_id: str) -> Path:
|
|
658
|
+
"""Get the path to the status file for a model.
|
|
659
|
+
|
|
660
|
+
Args:
|
|
661
|
+
model_id: The model ID to get status file path for
|
|
662
|
+
|
|
663
|
+
Returns:
|
|
664
|
+
Path: Path to the status file for this model
|
|
665
|
+
"""
|
|
666
|
+
status_dir = self._get_status_directory()
|
|
667
|
+
|
|
668
|
+
sanitized_model_id = re.sub(r"[^\w\-_]", "--", model_id)
|
|
669
|
+
return status_dir / f"{sanitized_model_id}.json"
|
|
670
|
+
|
|
671
|
+
def _read_model_download_status(self, model_id: str) -> ModelDownloadStatus | None:
|
|
672
|
+
"""Read download status for a specific model.
|
|
673
|
+
|
|
674
|
+
Args:
|
|
675
|
+
model_id: The model ID to get status for
|
|
676
|
+
|
|
677
|
+
Returns:
|
|
678
|
+
ModelDownloadStatus | None: The status if found, None otherwise
|
|
679
|
+
"""
|
|
680
|
+
status_file = self._get_status_file_path(model_id)
|
|
681
|
+
|
|
682
|
+
if not status_file.exists():
|
|
683
|
+
return None
|
|
684
|
+
|
|
685
|
+
try:
|
|
686
|
+
with status_file.open() as f:
|
|
687
|
+
data = json.load(f)
|
|
688
|
+
|
|
689
|
+
# Get file counts from simplified structure
|
|
690
|
+
total_files = data.get("total_files", 0)
|
|
691
|
+
downloaded_files = data.get("downloaded_files", 0)
|
|
692
|
+
|
|
693
|
+
# For simplified tracking, failed_files is calculated
|
|
694
|
+
failed_files = 0
|
|
695
|
+
if data.get("status") == "failed":
|
|
696
|
+
failed_files = total_files - downloaded_files
|
|
697
|
+
|
|
698
|
+
return ModelDownloadStatus(
|
|
699
|
+
model_id=data["model_id"],
|
|
700
|
+
status=data["status"],
|
|
701
|
+
started_at=data["started_at"],
|
|
702
|
+
updated_at=data["updated_at"],
|
|
703
|
+
total_files=total_files,
|
|
704
|
+
completed_files=downloaded_files,
|
|
705
|
+
failed_files=failed_files,
|
|
706
|
+
completed_at=data.get("completed_at"),
|
|
707
|
+
local_path=data.get("local_path"),
|
|
708
|
+
failed_at=data.get("failed_at"),
|
|
709
|
+
error_message=data.get("error_message"),
|
|
710
|
+
)
|
|
711
|
+
|
|
712
|
+
except (json.JSONDecodeError, KeyError) as e:
|
|
713
|
+
logger.warning("Failed to read status file for model '%s': %s", model_id, e)
|
|
714
|
+
return None
|
|
715
|
+
|
|
716
|
+
def _list_all_download_statuses(self) -> list[ModelDownloadStatus]:
|
|
717
|
+
"""List all model download statuses from status files.
|
|
718
|
+
|
|
719
|
+
Returns:
|
|
720
|
+
list[ModelDownloadStatus]: List of all download statuses
|
|
721
|
+
"""
|
|
722
|
+
status_dir = self._get_status_directory()
|
|
723
|
+
|
|
724
|
+
if not status_dir.exists():
|
|
725
|
+
return []
|
|
726
|
+
|
|
727
|
+
statuses = []
|
|
728
|
+
for status_file in status_dir.glob("*.json"):
|
|
729
|
+
try:
|
|
730
|
+
with status_file.open() as f:
|
|
731
|
+
data = json.load(f)
|
|
732
|
+
|
|
733
|
+
model_id = data.get("model_id", "")
|
|
734
|
+
if model_id:
|
|
735
|
+
status = self._read_model_download_status(model_id)
|
|
736
|
+
if status:
|
|
737
|
+
statuses.append(status)
|
|
738
|
+
|
|
739
|
+
except (json.JSONDecodeError, KeyError) as e:
|
|
740
|
+
logger.warning("Failed to read status file '%s': %s", status_file, e)
|
|
741
|
+
continue
|
|
742
|
+
|
|
743
|
+
return statuses
|
|
744
|
+
|
|
745
|
+
def _find_unfinished_downloads(self) -> list[str]:
|
|
746
|
+
"""Find model IDs with unfinished downloads from status files.
|
|
747
|
+
|
|
748
|
+
Returns:
|
|
749
|
+
list[str]: List of model IDs with status 'downloading' or 'failed'
|
|
750
|
+
"""
|
|
751
|
+
status_dir = self._get_status_directory()
|
|
752
|
+
|
|
753
|
+
if not status_dir.exists():
|
|
754
|
+
return []
|
|
755
|
+
|
|
756
|
+
unfinished_models = []
|
|
757
|
+
for status_file in status_dir.glob("*.json"):
|
|
758
|
+
try:
|
|
759
|
+
with status_file.open() as f:
|
|
760
|
+
data = json.load(f)
|
|
761
|
+
|
|
762
|
+
status = data.get("status", "")
|
|
763
|
+
model_id = data.get("model_id", "")
|
|
764
|
+
|
|
765
|
+
if model_id and status in ("downloading", "failed"):
|
|
766
|
+
unfinished_models.append(model_id)
|
|
767
|
+
|
|
768
|
+
except (json.JSONDecodeError, KeyError) as e:
|
|
769
|
+
logger.warning("Failed to read status file '%s': %s", status_file, e)
|
|
770
|
+
continue
|
|
771
|
+
|
|
772
|
+
return unfinished_models
|
|
773
|
+
|
|
774
|
+
async def on_handle_list_model_downloads_request(self, request: ListModelDownloadsRequest) -> ResultPayload:
|
|
775
|
+
"""Handle model download status requests asynchronously.
|
|
776
|
+
|
|
777
|
+
This method retrieves download status for a specific model or all models
|
|
778
|
+
from the local status files stored in the XDG data directory.
|
|
779
|
+
|
|
780
|
+
Args:
|
|
781
|
+
request: The status request containing optional model_id
|
|
782
|
+
|
|
783
|
+
Returns:
|
|
784
|
+
ResultPayload: Success result with download status list or failure with error details
|
|
785
|
+
"""
|
|
786
|
+
try:
|
|
787
|
+
# Get status information in a thread to avoid blocking the event loop
|
|
788
|
+
downloads = await asyncio.to_thread(self._get_download_statuses, request.model_id)
|
|
789
|
+
|
|
790
|
+
if request.model_id and not downloads:
|
|
791
|
+
result_details = f"No download status found for model '{request.model_id}'"
|
|
792
|
+
elif request.model_id:
|
|
793
|
+
result_details = f"Found download status for model '{request.model_id}'"
|
|
794
|
+
else:
|
|
795
|
+
result_details = f"Found {len(downloads)} download status records"
|
|
796
|
+
|
|
797
|
+
return ListModelDownloadsResultSuccess(
|
|
798
|
+
downloads=downloads,
|
|
799
|
+
result_details=result_details,
|
|
800
|
+
)
|
|
801
|
+
|
|
802
|
+
except Exception as e:
|
|
803
|
+
error_msg = f"Failed to get download status: {e}"
|
|
804
|
+
return ListModelDownloadsResultFailure(
|
|
805
|
+
result_details=error_msg,
|
|
806
|
+
exception=e,
|
|
807
|
+
)
|
|
808
|
+
|
|
809
|
+
def _list_models(self) -> list[ModelInfo]:
|
|
810
|
+
"""Synchronous model listing implementation using HuggingFace Hub SDK.
|
|
811
|
+
|
|
812
|
+
Uses scan_cache_dir to get information about cached models.
|
|
813
|
+
|
|
814
|
+
Returns:
|
|
815
|
+
list[ModelInfo]: List of model information
|
|
816
|
+
"""
|
|
817
|
+
try:
|
|
818
|
+
cache_info = scan_cache_dir()
|
|
819
|
+
models = []
|
|
820
|
+
|
|
821
|
+
for repo in cache_info.repos:
|
|
822
|
+
# Calculate total size across all revisions
|
|
823
|
+
total_size = sum(revision.size_on_disk for revision in repo.revisions)
|
|
824
|
+
|
|
825
|
+
model_info = ModelInfo(
|
|
826
|
+
model_id=repo.repo_id,
|
|
827
|
+
local_path=str(repo.repo_path),
|
|
828
|
+
size_bytes=total_size,
|
|
829
|
+
)
|
|
830
|
+
models.append(model_info)
|
|
831
|
+
|
|
832
|
+
except Exception as e:
|
|
833
|
+
logger.warning("Failed to scan cache directory: %s", e)
|
|
834
|
+
return []
|
|
835
|
+
else:
|
|
836
|
+
return models
|
|
837
|
+
|
|
838
|
+
def _delete_model(self, model_id: str) -> str:
|
|
839
|
+
"""Synchronous model deletion implementation using HuggingFace Hub SDK.
|
|
840
|
+
|
|
841
|
+
Uses scan_cache_dir to find and delete the model from cache.
|
|
842
|
+
|
|
843
|
+
Args:
|
|
844
|
+
model_id: The model ID to delete
|
|
845
|
+
|
|
846
|
+
Returns:
|
|
847
|
+
str: Information about what was deleted
|
|
848
|
+
|
|
849
|
+
Raises:
|
|
850
|
+
FileNotFoundError: If the model is not found in cache
|
|
851
|
+
"""
|
|
852
|
+
cache_info = scan_cache_dir()
|
|
853
|
+
|
|
854
|
+
# Find the repo to delete
|
|
855
|
+
repo_to_delete = None
|
|
856
|
+
for repo in cache_info.repos:
|
|
857
|
+
if repo.repo_id == model_id:
|
|
858
|
+
repo_to_delete = repo
|
|
859
|
+
break
|
|
860
|
+
|
|
861
|
+
if repo_to_delete is None:
|
|
862
|
+
error_msg = f"Model '{model_id}' not found in cache"
|
|
863
|
+
raise FileNotFoundError(error_msg)
|
|
864
|
+
|
|
865
|
+
# Get all revision hashes for this repo
|
|
866
|
+
revision_hashes = [revision.commit_hash for revision in repo_to_delete.revisions]
|
|
867
|
+
|
|
868
|
+
if not revision_hashes:
|
|
869
|
+
error_msg = f"No revisions found for model '{model_id}'"
|
|
870
|
+
raise FileNotFoundError(error_msg)
|
|
871
|
+
|
|
872
|
+
# Create delete strategy for all revisions of this repo
|
|
873
|
+
delete_strategy = cache_info.delete_revisions(*revision_hashes)
|
|
874
|
+
|
|
875
|
+
# Execute the deletion
|
|
876
|
+
delete_strategy.execute()
|
|
877
|
+
|
|
878
|
+
return f"Deleted model '{model_id}' (freed {delete_strategy.expected_freed_size_str})"
|
|
879
|
+
|
|
880
|
+
def _get_model_info(self, model_dir: Path) -> dict[str, str | int | float] | None:
|
|
881
|
+
"""Get information about a cached model.
|
|
882
|
+
|
|
883
|
+
Args:
|
|
884
|
+
model_dir: Path to the model directory in cache
|
|
885
|
+
|
|
886
|
+
Returns:
|
|
887
|
+
dict | None: Model information or None if not a valid model directory
|
|
888
|
+
"""
|
|
889
|
+
try:
|
|
890
|
+
# Extract model_id from directory name
|
|
891
|
+
# HuggingFace cache format: models--{org}--{model}--{hash}
|
|
892
|
+
dir_name = model_dir.name
|
|
893
|
+
if not dir_name.startswith("models--"):
|
|
894
|
+
return None
|
|
895
|
+
|
|
896
|
+
# Parse the model ID from the directory name
|
|
897
|
+
parts = dir_name.split("--")
|
|
898
|
+
if len(parts) >= MIN_CACHE_DIR_PARTS:
|
|
899
|
+
# Reconstruct model_id as org/model
|
|
900
|
+
model_id = f"{parts[1]}/{parts[2]}"
|
|
901
|
+
else:
|
|
902
|
+
model_id = dir_name[8:] # Remove "models--" prefix
|
|
903
|
+
|
|
904
|
+
# Calculate directory size
|
|
905
|
+
total_size = sum(f.stat().st_size for f in model_dir.rglob("*") if f.is_file())
|
|
906
|
+
|
|
907
|
+
return {
|
|
908
|
+
"model_id": model_id,
|
|
909
|
+
"local_path": str(model_dir),
|
|
910
|
+
"size_bytes": total_size,
|
|
911
|
+
"size_mb": round(total_size / (1024 * 1024), 2),
|
|
912
|
+
}
|
|
913
|
+
|
|
914
|
+
except Exception:
|
|
915
|
+
# If we can't parse the directory, skip it
|
|
916
|
+
return None
|
|
917
|
+
|
|
918
|
+
def _download_model(self, download_params: dict[str, str | list[str] | None]) -> Path:
|
|
919
|
+
"""Model download implementation.
|
|
920
|
+
|
|
921
|
+
Args:
|
|
922
|
+
download_params: Dictionary containing download parameters
|
|
923
|
+
|
|
924
|
+
Returns:
|
|
925
|
+
Path: Local path where the model was downloaded
|
|
926
|
+
"""
|
|
927
|
+
# Validate parameters and build download kwargs
|
|
928
|
+
download_kwargs = self._build_download_kwargs(download_params)
|
|
929
|
+
|
|
930
|
+
# Set the current model ID for progress tracking
|
|
931
|
+
model_id = download_params["model_id"]
|
|
932
|
+
if isinstance(model_id, str):
|
|
933
|
+
ModelDownloadTracker.set_current_model_id(model_id)
|
|
934
|
+
|
|
935
|
+
# Execute download with progress tracking
|
|
936
|
+
local_path = snapshot_download(**download_kwargs) # type: ignore[arg-type]
|
|
937
|
+
|
|
938
|
+
# Clear the current model ID when done
|
|
939
|
+
ModelDownloadTracker.set_current_model_id("")
|
|
940
|
+
|
|
941
|
+
return Path(local_path)
|
|
942
|
+
|
|
943
|
+
def _build_download_kwargs(self, download_params: dict[str, str | list[str] | None]) -> dict:
|
|
944
|
+
"""Build kwargs for snapshot_download with validation.
|
|
945
|
+
|
|
946
|
+
Args:
|
|
947
|
+
download_params: Dictionary containing download parameters
|
|
948
|
+
|
|
949
|
+
Returns:
|
|
950
|
+
dict: Validated download kwargs for snapshot_download
|
|
951
|
+
"""
|
|
952
|
+
param_model_id = download_params["model_id"]
|
|
953
|
+
local_dir = download_params["local_dir"]
|
|
954
|
+
revision = download_params["revision"]
|
|
955
|
+
allow_patterns = download_params["allow_patterns"]
|
|
956
|
+
ignore_patterns = download_params["ignore_patterns"]
|
|
957
|
+
|
|
958
|
+
# Build base kwargs with custom progress tracking
|
|
959
|
+
download_kwargs: dict[str, Any] = {
|
|
960
|
+
"repo_id": param_model_id,
|
|
961
|
+
"repo_type": "model",
|
|
962
|
+
"revision": revision,
|
|
963
|
+
"tqdm_class": ModelDownloadTracker,
|
|
964
|
+
}
|
|
965
|
+
|
|
966
|
+
# Add optional parameters
|
|
967
|
+
if local_dir is not None and isinstance(local_dir, str):
|
|
968
|
+
download_kwargs["local_dir"] = local_dir
|
|
969
|
+
if allow_patterns is not None and isinstance(allow_patterns, list):
|
|
970
|
+
download_kwargs["allow_patterns"] = allow_patterns
|
|
971
|
+
if ignore_patterns is not None and isinstance(ignore_patterns, list):
|
|
972
|
+
download_kwargs["ignore_patterns"] = ignore_patterns
|
|
973
|
+
|
|
974
|
+
return download_kwargs
|
|
975
|
+
|
|
976
|
+
def _parse_model_id(self, model_input: str) -> str:
|
|
977
|
+
"""Parse model ID from either a direct model ID or a Hugging Face URL.
|
|
978
|
+
|
|
979
|
+
Args:
|
|
980
|
+
model_input: Either a model ID (e.g., 'microsoft/DialoGPT-medium')
|
|
981
|
+
or a Hugging Face URL (e.g., 'https://huggingface.co/microsoft/DialoGPT-medium')
|
|
982
|
+
|
|
983
|
+
Returns:
|
|
984
|
+
str: The parsed model ID in the format 'namespace/repo_name' or 'repo_name'
|
|
985
|
+
"""
|
|
986
|
+
# If it's already a simple model ID (no URL scheme), return as-is
|
|
987
|
+
if not model_input.startswith(("http://", "https://")):
|
|
988
|
+
return model_input
|
|
989
|
+
|
|
990
|
+
# Parse the URL
|
|
991
|
+
parsed = urlparse(model_input)
|
|
992
|
+
|
|
993
|
+
# Check if it's a Hugging Face URL
|
|
994
|
+
if parsed.netloc in ("huggingface.co", "www.huggingface.co"):
|
|
995
|
+
# Extract the path and remove leading slash
|
|
996
|
+
path = parsed.path.lstrip("/")
|
|
997
|
+
|
|
998
|
+
# Remove any trailing parameters or fragments
|
|
999
|
+
# The model ID should be in the format: namespace/repo_name or just repo_name
|
|
1000
|
+
model_id_match = re.match(r"^([^/]+/[^/?#]+|[^/?#]+)", path)
|
|
1001
|
+
if model_id_match:
|
|
1002
|
+
return model_id_match.group(1)
|
|
1003
|
+
|
|
1004
|
+
# If we can't parse it, return the original input and let huggingface_hub handle the error
|
|
1005
|
+
return model_input
|
|
1006
|
+
|
|
1007
|
+
def _get_download_statuses(self, model_id: str | None = None) -> list[ModelDownloadStatus]:
|
|
1008
|
+
"""Get download statuses for a specific model or all models.
|
|
1009
|
+
|
|
1010
|
+
Args:
|
|
1011
|
+
model_id: Optional model ID to get status for. If None, returns all statuses.
|
|
1012
|
+
|
|
1013
|
+
Returns:
|
|
1014
|
+
list[ModelDownloadStatus]: List of download statuses
|
|
1015
|
+
"""
|
|
1016
|
+
if model_id:
|
|
1017
|
+
# Get status for specific model
|
|
1018
|
+
status = self._read_model_download_status(model_id)
|
|
1019
|
+
return [status] if status else []
|
|
1020
|
+
# Get all download statuses
|
|
1021
|
+
return self._list_all_download_statuses()
|
|
1022
|
+
|
|
1023
|
+
    async def on_handle_delete_model_download_request(self, request: DeleteModelDownloadRequest) -> ResultPayload:
        """Handle model download status deletion requests asynchronously.

        This method removes download tracking records for a specific model
        from the local status files stored in the XDG data directory.
        If the model is currently downloading or failed, it also cancels
        the download task and deletes any cached model files.

        Args:
            request: The delete request containing model_id

        Returns:
            ResultPayload: Success result with deletion confirmation or failure with error details
        """
        model_id = request.model_id

        try:
            # Check current download status first — this pre-cancellation
            # snapshot decides further below whether cached files get removed.
            download_status = await asyncio.to_thread(self._read_model_download_status, model_id)

            # Cancel active download process if it exists
            if model_id in self._download_processes:
                process = self._download_processes[model_id]
                await cancel_subprocess(process, f"download process for model '{model_id}'")
                del self._download_processes[model_id]

            # Cancel the tracking task as well; the entry is dropped whether
            # or not the task had already finished.
            if model_id in self._download_tasks:
                task = self._download_tasks[model_id]
                if not task.done():
                    task.cancel()
                    logger.info("Cancelled active download task for model '%s'", model_id)
                del self._download_tasks[model_id]

            # Delete status file (raises FileNotFoundError if absent, which is
            # mapped to a failure result by the handler below)
            deleted_path = await asyncio.to_thread(self._delete_model_download_status, model_id)

            # Only delete cached model if it's not completed
            if download_status and download_status.status != "completed":
                try:
                    await asyncio.to_thread(self._delete_model, model_id)
                except FileNotFoundError:
                    # Partial download may never have written cache files; not an error.
                    logger.debug("No cached model files found for '%s'", model_id)

            result_details = f"Successfully deleted download status for model '{model_id}'"

            return DeleteModelDownloadResultSuccess(
                model_id=model_id,
                deleted_path=deleted_path,
                result_details=result_details,
            )

        except FileNotFoundError:
            error_msg = f"Download status for model '{model_id}' not found"
            return DeleteModelDownloadResultFailure(
                result_details=error_msg,
                exception=FileNotFoundError(error_msg),
            )

        except Exception as e:
            error_msg = f"Failed to delete download status for '{model_id}': {e}"
            return DeleteModelDownloadResultFailure(
                result_details=error_msg,
                exception=e,
            )
|
|
1087
|
+
|
|
1088
|
+
def _delete_model_download_status(self, model_id: str) -> str:
|
|
1089
|
+
"""Delete download status file for a specific model.
|
|
1090
|
+
|
|
1091
|
+
Args:
|
|
1092
|
+
model_id: The model ID to remove download status for
|
|
1093
|
+
|
|
1094
|
+
Returns:
|
|
1095
|
+
str: Path to the deleted status file
|
|
1096
|
+
|
|
1097
|
+
Raises:
|
|
1098
|
+
FileNotFoundError: If the status file is not found
|
|
1099
|
+
"""
|
|
1100
|
+
status_file = self._get_status_file_path(model_id)
|
|
1101
|
+
|
|
1102
|
+
if not status_file.exists():
|
|
1103
|
+
msg = f"Download status file not found for model '{model_id}'"
|
|
1104
|
+
raise FileNotFoundError(msg)
|
|
1105
|
+
|
|
1106
|
+
status_file.unlink()
|
|
1107
|
+
return str(status_file)
|