griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. griptape_nodes/__init__.py +8 -942
  2. griptape_nodes/__main__.py +6 -0
  3. griptape_nodes/app/app.py +48 -86
  4. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
  5. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
  6. griptape_nodes/cli/__init__.py +1 -0
  7. griptape_nodes/cli/commands/__init__.py +1 -0
  8. griptape_nodes/cli/commands/config.py +74 -0
  9. griptape_nodes/cli/commands/engine.py +80 -0
  10. griptape_nodes/cli/commands/init.py +550 -0
  11. griptape_nodes/cli/commands/libraries.py +96 -0
  12. griptape_nodes/cli/commands/models.py +504 -0
  13. griptape_nodes/cli/commands/self.py +120 -0
  14. griptape_nodes/cli/main.py +56 -0
  15. griptape_nodes/cli/shared.py +75 -0
  16. griptape_nodes/common/__init__.py +1 -0
  17. griptape_nodes/common/directed_graph.py +71 -0
  18. griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
  19. griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
  20. griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
  21. griptape_nodes/exe_types/core_types.py +60 -2
  22. griptape_nodes/exe_types/node_types.py +257 -38
  23. griptape_nodes/exe_types/param_components/__init__.py +1 -0
  24. griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
  25. griptape_nodes/machines/control_flow.py +195 -94
  26. griptape_nodes/machines/dag_builder.py +207 -0
  27. griptape_nodes/machines/fsm.py +10 -1
  28. griptape_nodes/machines/parallel_resolution.py +558 -0
  29. griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
  30. griptape_nodes/node_library/library_registry.py +34 -1
  31. griptape_nodes/retained_mode/events/app_events.py +5 -1
  32. griptape_nodes/retained_mode/events/base_events.py +9 -9
  33. griptape_nodes/retained_mode/events/config_events.py +30 -0
  34. griptape_nodes/retained_mode/events/execution_events.py +2 -2
  35. griptape_nodes/retained_mode/events/model_events.py +296 -0
  36. griptape_nodes/retained_mode/events/node_events.py +4 -3
  37. griptape_nodes/retained_mode/griptape_nodes.py +34 -12
  38. griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
  39. griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
  40. griptape_nodes/retained_mode/managers/config_manager.py +44 -3
  41. griptape_nodes/retained_mode/managers/context_manager.py +6 -5
  42. griptape_nodes/retained_mode/managers/event_manager.py +8 -2
  43. griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
  44. griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
  45. griptape_nodes/retained_mode/managers/library_manager.py +35 -25
  46. griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
  47. griptape_nodes/retained_mode/managers/node_manager.py +102 -220
  48. griptape_nodes/retained_mode/managers/object_manager.py +11 -5
  49. griptape_nodes/retained_mode/managers/os_manager.py +28 -13
  50. griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
  51. griptape_nodes/retained_mode/managers/settings.py +116 -7
  52. griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
  53. griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
  54. griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
  55. griptape_nodes/retained_mode/retained_mode.py +19 -0
  56. griptape_nodes/servers/__init__.py +1 -0
  57. griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
  58. griptape_nodes/{app/api.py → servers/static.py} +43 -40
  59. griptape_nodes/traits/add_param_button.py +1 -1
  60. griptape_nodes/traits/button.py +334 -6
  61. griptape_nodes/traits/color_picker.py +66 -0
  62. griptape_nodes/traits/multi_options.py +188 -0
  63. griptape_nodes/traits/numbers_selector.py +77 -0
  64. griptape_nodes/traits/options.py +93 -2
  65. griptape_nodes/traits/traits.json +4 -0
  66. griptape_nodes/utils/async_utils.py +31 -0
  67. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
  68. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
  69. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
  70. /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
  71. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,1107 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ import re
7
+ import sys
8
+ import threading
9
+ from dataclasses import dataclass
10
+ from datetime import UTC, datetime
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any
13
+ from urllib.parse import urlparse
14
+
15
+ from huggingface_hub import list_models, scan_cache_dir, snapshot_download
16
+ from huggingface_hub.utils.tqdm import tqdm
17
+ from xdg_base_dirs import xdg_data_home
18
+
19
+ from griptape_nodes.retained_mode.events.app_events import AppInitializationComplete
20
+ from griptape_nodes.retained_mode.events.model_events import (
21
+ DeleteModelDownloadRequest,
22
+ DeleteModelDownloadResultFailure,
23
+ DeleteModelDownloadResultSuccess,
24
+ DeleteModelRequest,
25
+ DeleteModelResultFailure,
26
+ DeleteModelResultSuccess,
27
+ DownloadModelRequest,
28
+ DownloadModelResultFailure,
29
+ DownloadModelResultSuccess,
30
+ ListModelDownloadsRequest,
31
+ ListModelDownloadsResultFailure,
32
+ ListModelDownloadsResultSuccess,
33
+ ListModelsRequest,
34
+ ListModelsResultFailure,
35
+ ListModelsResultSuccess,
36
+ ModelDownloadStatus,
37
+ ModelInfo,
38
+ QueryInfo,
39
+ SearchModelsRequest,
40
+ SearchModelsResultFailure,
41
+ SearchModelsResultSuccess,
42
+ )
43
+ from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
44
+ from griptape_nodes.utils.async_utils import cancel_subprocess
45
+
46
+ if TYPE_CHECKING:
47
+ from griptape_nodes.retained_mode.events.base_events import ResultPayload
48
+ from griptape_nodes.retained_mode.managers.event_manager import EventManager
49
+
50
+ logger = logging.getLogger("griptape_nodes")
51
+
52
+
53
+ HTTP_UNAUTHORIZED = 401
54
+ HTTP_FORBIDDEN = 403
55
+
56
+ MIN_CACHE_DIR_PARTS = 3
57
+
58
+
59
@dataclass
class SearchResultsData:
    """Container for one Hugging Face Hub model search response."""

    # Search hits, already converted to ModelInfo payload objects.
    models: list[ModelInfo]
    # Number of hits returned (equals len(models) for this implementation).
    total_results: int
    # Echo of the query parameters that produced these results.
    query_info: QueryInfo
66
+
67
+
68
class ModelDownloadTracker(tqdm):
    """Custom tqdm progress bar that tracks aggregate model download progress.

    huggingface_hub instantiates the ``tqdm_class`` itself, so the model id is
    passed out-of-band via :meth:`set_current_model_id` before a download
    starts. Progress is mirrored into a JSON status file under the XDG data
    directory so other processes (CLI, manager) can observe or resume
    downloads.
    """

    # Serializes status-file reads/writes across tracker instances and threads.
    _file_lock = threading.Lock()
    # Class-level "current model" that newly created trackers pick up.
    _current_model_id = ""

    @classmethod
    def set_current_model_id(cls, model_id: str) -> None:
        """Set the current model being downloaded."""
        cls._current_model_id = model_id

    def __init__(self, *args, model_id: str = "", **kwargs):
        """Create a tracker, inheriting the class-level model id when none is given."""
        if not model_id and self._current_model_id:
            model_id = self._current_model_id

        super().__init__(*args, **kwargs)
        self.model_id = model_id
        self.start_time = datetime.now(UTC).isoformat()

        logger.debug(
            "ModelDownloadTracker created - model_id: %s, total: %s, desc: %s, args: %s",
            self.model_id,
            self.total,
            getattr(self, "desc", None),
            args,
        )

        # Only track progress when we know which model this bar belongs to.
        if self.model_id:
            self._init_status_file()

    def update(self, n: int = 1) -> None:
        """Override update to track progress in status file."""
        logger.debug(
            "ModelDownloadTracker update - model_id: %s, n: %s, self.n: %s, total: %s",
            self.model_id,
            n,
            self.n,
            self.total,
        )
        super().update(n)
        self._update_status_file()

    def close(self) -> None:
        """Override close to log download completion."""
        super().close()
        logger.debug(
            "ModelDownloadTracker close - model_id: %s, self.n: %s, total: %s", self.model_id, self.n, self.total
        )

    def _get_status_file_path(self) -> Path:
        """Get the path to the status file for this model."""
        status_dir = xdg_data_home() / "griptape_nodes" / "model_downloads"
        status_dir.mkdir(parents=True, exist_ok=True)

        # Model IDs contain "/" (org/name); collapse unsafe chars to "--".
        sanitized_model_id = re.sub(r"[^\w\-_]", "--", self.model_id)
        return status_dir / f"{sanitized_model_id}.json"

    def _init_status_file(self) -> None:
        """Initialize the status file for this model."""
        try:
            with self._file_lock:
                status_file = self._get_status_file_path()
                current_time = datetime.now(UTC).isoformat()

                # DEBUG, not INFO: a tracker (and therefore this call) is
                # created per downloaded file, so INFO would flood the log.
                logger.debug(
                    "ModelDownloadTracker initializing status file: %s (total_files=%s)", status_file, self.total
                )

                data = {
                    "model_id": self.model_id,
                    "status": "downloading",
                    "started_at": current_time,
                    "updated_at": current_time,
                    "total_files": self.total or 0,
                    "downloaded_files": 0,
                    "progress_percent": 0.0,
                }

                with status_file.open("w") as f:
                    json.dump(data, f, indent=2)

                logger.debug("ModelDownloadTracker status file initialized successfully")

        except Exception:
            # Progress tracking must never break the actual download.
            logger.exception("ModelDownloadTracker._init_status_file failed")

    def _update_status_file(self) -> None:
        """Update the status file with current progress."""
        if not self.model_id:
            logger.warning("ModelDownloadTracker._update_status_file called with empty model_id")
            return

        try:
            with self._file_lock:
                status_file = self._get_status_file_path()
                # DEBUG, not INFO: this runs on every tqdm progress callback.
                logger.debug("ModelDownloadTracker updating status file: %s", status_file)

                if not status_file.exists():
                    logger.warning("Status file does not exist: %s", status_file)
                    return

                with status_file.open() as f:
                    data = json.load(f)

                current_time = datetime.now(UTC).isoformat()
                progress_percent = (self.n / self.total * 100) if self.total else 0

                logger.debug(
                    "ModelDownloadTracker updating progress: files=%d/%d, percent=%.1f%%",
                    self.n,
                    self.total,
                    progress_percent,
                )

                data.update(
                    {
                        "downloaded_files": self.n,
                        "progress_percent": progress_percent,
                        "updated_at": current_time,
                    }
                )

                with status_file.open("w") as f:
                    json.dump(data, f, indent=2)

                logger.debug("ModelDownloadTracker status file updated successfully")

        except Exception:
            logger.exception("ModelDownloadTracker._update_status_file failed")
197
+
198
+
199
+ class ModelManager:
200
+ """A manager for downloading models from Hugging Face Hub.
201
+
202
+ This manager provides async handlers for downloading models using the Hugging Face Hub API.
203
+ It supports downloading entire model repositories or specific files, with caching and
204
+ local storage management.
205
+ """
206
+
207
+ def __init__(self, event_manager: EventManager | None = None) -> None:
208
+ """Initialize the ModelManager.
209
+
210
+ Args:
211
+ event_manager: The EventManager instance to use for event handling.
212
+ """
213
+ self._download_tasks = {}
214
+ self._download_processes = {}
215
+
216
+ if event_manager is not None:
217
+ event_manager.assign_manager_to_request_type(DownloadModelRequest, self.on_handle_download_model_request)
218
+ event_manager.assign_manager_to_request_type(ListModelsRequest, self.on_handle_list_models_request)
219
+ event_manager.assign_manager_to_request_type(DeleteModelRequest, self.on_handle_delete_model_request)
220
+ event_manager.assign_manager_to_request_type(SearchModelsRequest, self.on_handle_search_models_request)
221
+ event_manager.assign_manager_to_request_type(
222
+ ListModelDownloadsRequest, self.on_handle_list_model_downloads_request
223
+ )
224
+ event_manager.assign_manager_to_request_type(
225
+ DeleteModelDownloadRequest, self.on_handle_delete_model_download_request
226
+ )
227
+
228
+ event_manager.add_listener_to_app_event(AppInitializationComplete, self.on_app_initialization_complete)
229
+
230
+ def download_model(
231
+ self,
232
+ model_id: str,
233
+ local_dir: str | None = None,
234
+ revision: str = "main",
235
+ allow_patterns: list[str] | None = None,
236
+ ignore_patterns: list[str] | None = None,
237
+ ) -> str:
238
+ """Direct model download method that can be used without event system.
239
+
240
+ This method contains the core download logic without going through
241
+ the event system, avoiding recursion issues.
242
+
243
+ Args:
244
+ model_id: Model ID to download
245
+ local_dir: Optional local directory to download to
246
+ revision: Git revision to download
247
+ allow_patterns: Optional glob patterns to include
248
+ ignore_patterns: Optional glob patterns to exclude
249
+
250
+ Returns:
251
+ str: Local path where the model was downloaded
252
+
253
+ Raises:
254
+ Exception: If download fails
255
+ """
256
+ # Set up progress tracking
257
+ ModelDownloadTracker.set_current_model_id(model_id)
258
+
259
+ try:
260
+ # Build download kwargs
261
+ download_kwargs = {
262
+ "repo_id": model_id,
263
+ "repo_type": "model",
264
+ "revision": revision,
265
+ "tqdm_class": ModelDownloadTracker,
266
+ }
267
+
268
+ # Add optional parameters
269
+ if local_dir:
270
+ download_kwargs["local_dir"] = local_dir
271
+ if allow_patterns:
272
+ download_kwargs["allow_patterns"] = allow_patterns
273
+ if ignore_patterns:
274
+ download_kwargs["ignore_patterns"] = ignore_patterns
275
+
276
+ # Execute download with progress tracking
277
+ local_path = snapshot_download(**download_kwargs) # type: ignore[arg-type]
278
+
279
+ return str(local_path)
280
+
281
+ finally:
282
+ # Clear the current model ID when done
283
+ ModelDownloadTracker.set_current_model_id("")
284
+
285
+ def _get_status_directory(self) -> Path:
286
+ """Get the status directory path for model downloads.
287
+
288
+ Returns:
289
+ Path: Path to the status directory, creating it if needed
290
+ """
291
+ status_dir = xdg_data_home() / "griptape_nodes" / "model_downloads"
292
+ status_dir.mkdir(parents=True, exist_ok=True)
293
+ return status_dir
294
+
295
+ async def on_handle_download_model_request(self, request: DownloadModelRequest) -> ResultPayload:
296
+ """Handle model download requests asynchronously.
297
+
298
+ This method starts a background task to download models from Hugging Face Hub using the provided parameters.
299
+ It supports both model IDs and full URLs, and can download entire repositories
300
+ or specific files based on the patterns provided.
301
+
302
+ Args:
303
+ request: The download request containing model ID and options
304
+
305
+ Returns:
306
+ ResultPayload: Success result indicating download started or failure with error details
307
+ """
308
+ parsed_model_id = self._parse_model_id(request.model_id)
309
+ if parsed_model_id != request.model_id:
310
+ logger.debug("Parsed model ID '%s' from URL '%s'", parsed_model_id, request.model_id)
311
+
312
+ try:
313
+ download_params = {
314
+ "model_id": parsed_model_id,
315
+ "local_dir": request.local_dir,
316
+ "revision": request.revision,
317
+ "allow_patterns": request.allow_patterns,
318
+ "ignore_patterns": request.ignore_patterns,
319
+ }
320
+
321
+ task = asyncio.create_task(self._download_model_task(download_params))
322
+ self._download_tasks[parsed_model_id] = task
323
+
324
+ result_details = f"Started background download for model '{parsed_model_id}'"
325
+
326
+ return DownloadModelResultSuccess(
327
+ model_id=parsed_model_id,
328
+ result_details=result_details,
329
+ )
330
+
331
+ except Exception as e:
332
+ error_msg = f"Failed to start download for model '{request.model_id}': {e}"
333
+ return DownloadModelResultFailure(
334
+ result_details=error_msg,
335
+ exception=e,
336
+ )
337
+
338
+ async def _download_model_task(self, download_params: dict[str, str | list[str] | None]) -> None:
339
+ """Background task for downloading a model using CLI command.
340
+
341
+ Args:
342
+ download_params: Dictionary containing download parameters
343
+ """
344
+ model_id = download_params["model_id"]
345
+ logger.info("Starting background download for model: %s", model_id)
346
+
347
+ # Build CLI command arguments
348
+ cmd = [sys.executable, "-m", "griptape_nodes", "models", "download", str(model_id)]
349
+
350
+ # Add optional parameters
351
+ if download_params.get("local_dir"):
352
+ cmd.extend(["--local-dir", str(download_params["local_dir"])])
353
+ if download_params.get("revision") and download_params["revision"] != "main":
354
+ cmd.extend(["--revision", str(download_params["revision"])])
355
+
356
+ try:
357
+ # Start subprocess
358
+ process = await asyncio.create_subprocess_exec(
359
+ *cmd,
360
+ stdout=asyncio.subprocess.PIPE,
361
+ stderr=asyncio.subprocess.PIPE,
362
+ )
363
+
364
+ # Store process for cancellation
365
+ if isinstance(model_id, str):
366
+ self._download_processes[model_id] = process
367
+
368
+ # Wait for completion (this can be cancelled)
369
+ stdout, stderr = await process.communicate()
370
+
371
+ if process.returncode == 0:
372
+ logger.info("Successfully downloaded model '%s'", model_id)
373
+ else:
374
+ logger.error("Download failed for model '%s': %s", model_id, stderr.decode())
375
+
376
+ except asyncio.CancelledError:
377
+ logger.info("Download cancelled for model '%s'", model_id)
378
+ raise
379
+
380
+ except Exception:
381
+ logger.exception("Error downloading model '%s'", model_id)
382
+
383
+ finally:
384
+ if isinstance(model_id, str):
385
+ if model_id in self._download_tasks:
386
+ del self._download_tasks[model_id]
387
+ if model_id in self._download_processes:
388
+ del self._download_processes[model_id]
389
+
390
+ async def on_handle_list_models_request(self, request: ListModelsRequest) -> ResultPayload: # noqa: ARG002
391
+ """Handle model listing requests asynchronously.
392
+
393
+ This method scans the local Hugging Face cache directory to find downloaded models
394
+ and returns information about each model including path, size, and metadata.
395
+
396
+ Args:
397
+ request: The list request (no parameters needed)
398
+
399
+ Returns:
400
+ ResultPayload: Success result with model list or failure with error details
401
+ """
402
+ try:
403
+ # Get models in a thread to avoid blocking the event loop
404
+ models = await asyncio.to_thread(self._list_models)
405
+
406
+ result_details = f"Found {len(models)} cached models"
407
+
408
+ return ListModelsResultSuccess(
409
+ models=models,
410
+ result_details=result_details,
411
+ )
412
+
413
+ except Exception as e:
414
+ error_msg = f"Failed to list models: {e}"
415
+ return ListModelsResultFailure(
416
+ result_details=error_msg,
417
+ exception=e,
418
+ )
419
+
420
+ async def on_handle_delete_model_request(self, request: DeleteModelRequest) -> ResultPayload:
421
+ """Handle model deletion requests asynchronously.
422
+
423
+ This method removes a model from the local Hugging Face cache directory and
424
+ cleans up any associated download tracking records.
425
+
426
+ Args:
427
+ request: The delete request containing model_id
428
+
429
+ Returns:
430
+ ResultPayload: Success result with deletion confirmation or failure with error details
431
+ """
432
+ # Parse the model ID from potential URL
433
+ model_id = request.model_id
434
+
435
+ deleted_items = []
436
+
437
+ try:
438
+ deleted_path = await asyncio.to_thread(self._delete_model, model_id)
439
+ deleted_items.append(f"model files from '{deleted_path}'")
440
+
441
+ except FileNotFoundError:
442
+ logger.debug("No model files found for '%s' in cache", model_id)
443
+
444
+ except Exception as e:
445
+ error_msg = f"Failed to delete model files for '{model_id}': {e}"
446
+ return DeleteModelResultFailure(
447
+ result_details=error_msg,
448
+ exception=e,
449
+ )
450
+
451
+ if not deleted_items:
452
+ error_msg = f"Model '{model_id}' not found (no cached files or download records)"
453
+ return DeleteModelResultFailure(
454
+ result_details=error_msg,
455
+ exception=FileNotFoundError(error_msg),
456
+ )
457
+
458
+ deleted_description = " and ".join(deleted_items)
459
+ result_details = f"Successfully deleted {deleted_description} for model '{model_id}'"
460
+
461
+ return DeleteModelResultSuccess(
462
+ model_id=model_id,
463
+ deleted_path=deleted_items[0] if deleted_items else "",
464
+ result_details=result_details,
465
+ )
466
+
467
+ async def on_handle_search_models_request(self, request: SearchModelsRequest) -> ResultPayload:
468
+ """Handle model search requests asynchronously.
469
+
470
+ This method searches for models on Hugging Face Hub using the provided parameters.
471
+ It supports filtering by query, task, library, author, and tags.
472
+
473
+ Args:
474
+ request: The search request containing search parameters
475
+
476
+ Returns:
477
+ ResultPayload: Success result with model list or failure with error details
478
+ """
479
+ try:
480
+ # Search models in a thread to avoid blocking the event loop
481
+ search_results = await asyncio.to_thread(self._search_models, request)
482
+ except Exception as e:
483
+ error_msg = f"Failed to search models: {e}"
484
+ return SearchModelsResultFailure(
485
+ result_details=error_msg,
486
+ exception=e,
487
+ )
488
+ else:
489
+ result_details = f"Found {len(search_results.models)} models"
490
+ return SearchModelsResultSuccess(
491
+ models=search_results.models,
492
+ total_results=search_results.total_results,
493
+ query_info=search_results.query_info,
494
+ result_details=result_details,
495
+ )
496
+
497
+ def _search_models(self, request: SearchModelsRequest) -> SearchResultsData:
498
+ """Synchronous model search implementation.
499
+
500
+ Searches for models on Hugging Face Hub using the huggingface_hub API.
501
+
502
+ Args:
503
+ request: The search request parameters
504
+
505
+ Returns:
506
+ SearchResultsData: Dataclass containing models list, total results, and query info
507
+ """
508
+ # Build search parameters
509
+ search_params = {}
510
+
511
+ if request.query:
512
+ search_params["search"] = request.query
513
+ if request.task:
514
+ search_params["task"] = request.task
515
+ if request.library:
516
+ search_params["library"] = request.library
517
+ if request.author:
518
+ search_params["author"] = request.author
519
+ if request.tags:
520
+ search_params["tags"] = request.tags
521
+
522
+ # Validate and set sort parameters
523
+ valid_sorts = ["downloads", "likes", "updated", "created"]
524
+ sort_param = request.sort if request.sort in valid_sorts else "downloads"
525
+ search_params["sort"] = sort_param
526
+
527
+ # Only add direction for sorts that support it (downloads only supports descending)
528
+ if sort_param != "downloads":
529
+ # Convert direction to the format expected by HF Hub API (-1 for asc, 1 for desc)
530
+ direction_param = -1 if request.direction == "asc" else 1
531
+ search_params["direction"] = direction_param
532
+
533
+ # Limit results (max 100 as per HF Hub API)
534
+ limit = min(max(1, request.limit), 100)
535
+
536
+ # Perform the search
537
+ models_iterator = list_models(limit=limit, **search_params)
538
+
539
+ # Convert models to list and extract information
540
+ models_list = []
541
+ for model in models_iterator:
542
+ created_at = getattr(model, "created_at", None)
543
+ updated_at = getattr(model, "last_modified", None)
544
+
545
+ model_info = ModelInfo(
546
+ model_id=model.id,
547
+ author=getattr(model, "author", None),
548
+ downloads=getattr(model, "downloads", None),
549
+ likes=getattr(model, "likes", None),
550
+ created_at=created_at.isoformat() if created_at else None,
551
+ updated_at=updated_at.isoformat() if updated_at else None,
552
+ task=getattr(model, "pipeline_tag", None),
553
+ library=getattr(model, "library_name", None),
554
+ tags=getattr(model, "tags", None),
555
+ )
556
+ models_list.append(model_info)
557
+
558
+ # Prepare query info for response
559
+ query_info = QueryInfo(
560
+ query=request.query,
561
+ task=request.task,
562
+ library=request.library,
563
+ author=request.author,
564
+ tags=request.tags,
565
+ limit=limit,
566
+ sort=sort_param,
567
+ direction=request.direction, # Keep the original user-friendly format
568
+ )
569
+
570
+ return SearchResultsData(
571
+ models=models_list,
572
+ total_results=len(models_list),
573
+ query_info=query_info,
574
+ )
575
+
576
+ async def on_app_initialization_complete(self, _payload: AppInitializationComplete) -> None:
577
+ """Handle app initialization complete event by downloading configured models and resuming unfinished downloads.
578
+
579
+ Args:
580
+ payload: The app initialization complete payload
581
+ """
582
+ # Get models to download from configuration
583
+ config_manager = GriptapeNodes.ConfigManager()
584
+ models_to_download = config_manager.get_config_value(
585
+ "app_events.on_app_initialization_complete.models_to_download", default=[]
586
+ )
587
+
588
+ # Find unfinished downloads to resume
589
+ unfinished_models = await asyncio.to_thread(self._find_unfinished_downloads)
590
+
591
+ # Combine new downloads and unfinished ones, avoiding duplicates
592
+ all_models = list(
593
+ dict.fromkeys(
594
+ [
595
+ *[model_id for model_id in models_to_download if model_id], # Filter empty strings
596
+ *unfinished_models,
597
+ ]
598
+ )
599
+ )
600
+
601
+ if not all_models:
602
+ logger.debug("No models to download or resume")
603
+ return
604
+
605
+ logger.info(
606
+ "Starting download/resume of %d models (%d new, %d resuming)",
607
+ len(all_models),
608
+ len(models_to_download),
609
+ len(unfinished_models),
610
+ )
611
+
612
+ # Create download tasks for concurrent execution
613
+ download_tasks = []
614
+ for model_id in all_models:
615
+ task = asyncio.create_task(self._download_model_with_logging(model_id))
616
+ download_tasks.append(task)
617
+
618
+ # Wait for all downloads to complete
619
+ results = await asyncio.gather(*download_tasks, return_exceptions=True)
620
+
621
+ # Log summary of results
622
+ successful = sum(1 for result in results if not isinstance(result, Exception))
623
+ failed = len(results) - successful
624
+
625
+ logger.info("Completed automatic model downloads: %d successful, %d failed", successful, failed)
626
+
627
+ async def _download_model_with_logging(self, model_id: str) -> None:
628
+ """Download a single model with proper logging.
629
+
630
+ Args:
631
+ model_id: The model ID to download
632
+ """
633
+ logger.info("Auto-downloading model: %s", model_id)
634
+
635
+ # Create download request with default parameters
636
+ request = DownloadModelRequest(
637
+ model_id=model_id,
638
+ local_dir=None, # Use default cache directory
639
+ revision="main",
640
+ allow_patterns=None,
641
+ ignore_patterns=None,
642
+ )
643
+
644
+ try:
645
+ # Run the download asynchronously
646
+ result = await self.on_handle_download_model_request(request)
647
+
648
+ if isinstance(result, DownloadModelResultFailure):
649
+ logger.warning("Failed to auto-download model '%s': %s", model_id, result.result_details)
650
+ elif not isinstance(result, DownloadModelResultSuccess):
651
+ logger.warning("Unknown result type for model '%s' download: %s", model_id, type(result))
652
+
653
+ except Exception as e:
654
+ logger.error("Unexpected error auto-downloading model '%s': %s", model_id, e)
655
+ raise
656
+
657
+ def _get_status_file_path(self, model_id: str) -> Path:
658
+ """Get the path to the status file for a model.
659
+
660
+ Args:
661
+ model_id: The model ID to get status file path for
662
+
663
+ Returns:
664
+ Path: Path to the status file for this model
665
+ """
666
+ status_dir = self._get_status_directory()
667
+
668
+ sanitized_model_id = re.sub(r"[^\w\-_]", "--", model_id)
669
+ return status_dir / f"{sanitized_model_id}.json"
670
+
671
+ def _read_model_download_status(self, model_id: str) -> ModelDownloadStatus | None:
672
+ """Read download status for a specific model.
673
+
674
+ Args:
675
+ model_id: The model ID to get status for
676
+
677
+ Returns:
678
+ ModelDownloadStatus | None: The status if found, None otherwise
679
+ """
680
+ status_file = self._get_status_file_path(model_id)
681
+
682
+ if not status_file.exists():
683
+ return None
684
+
685
+ try:
686
+ with status_file.open() as f:
687
+ data = json.load(f)
688
+
689
+ # Get file counts from simplified structure
690
+ total_files = data.get("total_files", 0)
691
+ downloaded_files = data.get("downloaded_files", 0)
692
+
693
+ # For simplified tracking, failed_files is calculated
694
+ failed_files = 0
695
+ if data.get("status") == "failed":
696
+ failed_files = total_files - downloaded_files
697
+
698
+ return ModelDownloadStatus(
699
+ model_id=data["model_id"],
700
+ status=data["status"],
701
+ started_at=data["started_at"],
702
+ updated_at=data["updated_at"],
703
+ total_files=total_files,
704
+ completed_files=downloaded_files,
705
+ failed_files=failed_files,
706
+ completed_at=data.get("completed_at"),
707
+ local_path=data.get("local_path"),
708
+ failed_at=data.get("failed_at"),
709
+ error_message=data.get("error_message"),
710
+ )
711
+
712
+ except (json.JSONDecodeError, KeyError) as e:
713
+ logger.warning("Failed to read status file for model '%s': %s", model_id, e)
714
+ return None
715
+
716
+ def _list_all_download_statuses(self) -> list[ModelDownloadStatus]:
717
+ """List all model download statuses from status files.
718
+
719
+ Returns:
720
+ list[ModelDownloadStatus]: List of all download statuses
721
+ """
722
+ status_dir = self._get_status_directory()
723
+
724
+ if not status_dir.exists():
725
+ return []
726
+
727
+ statuses = []
728
+ for status_file in status_dir.glob("*.json"):
729
+ try:
730
+ with status_file.open() as f:
731
+ data = json.load(f)
732
+
733
+ model_id = data.get("model_id", "")
734
+ if model_id:
735
+ status = self._read_model_download_status(model_id)
736
+ if status:
737
+ statuses.append(status)
738
+
739
+ except (json.JSONDecodeError, KeyError) as e:
740
+ logger.warning("Failed to read status file '%s': %s", status_file, e)
741
+ continue
742
+
743
+ return statuses
744
+
745
+ def _find_unfinished_downloads(self) -> list[str]:
746
+ """Find model IDs with unfinished downloads from status files.
747
+
748
+ Returns:
749
+ list[str]: List of model IDs with status 'downloading' or 'failed'
750
+ """
751
+ status_dir = self._get_status_directory()
752
+
753
+ if not status_dir.exists():
754
+ return []
755
+
756
+ unfinished_models = []
757
+ for status_file in status_dir.glob("*.json"):
758
+ try:
759
+ with status_file.open() as f:
760
+ data = json.load(f)
761
+
762
+ status = data.get("status", "")
763
+ model_id = data.get("model_id", "")
764
+
765
+ if model_id and status in ("downloading", "failed"):
766
+ unfinished_models.append(model_id)
767
+
768
+ except (json.JSONDecodeError, KeyError) as e:
769
+ logger.warning("Failed to read status file '%s': %s", status_file, e)
770
+ continue
771
+
772
+ return unfinished_models
773
+
774
+ async def on_handle_list_model_downloads_request(self, request: ListModelDownloadsRequest) -> ResultPayload:
775
+ """Handle model download status requests asynchronously.
776
+
777
+ This method retrieves download status for a specific model or all models
778
+ from the local status files stored in the XDG data directory.
779
+
780
+ Args:
781
+ request: The status request containing optional model_id
782
+
783
+ Returns:
784
+ ResultPayload: Success result with download status list or failure with error details
785
+ """
786
+ try:
787
+ # Get status information in a thread to avoid blocking the event loop
788
+ downloads = await asyncio.to_thread(self._get_download_statuses, request.model_id)
789
+
790
+ if request.model_id and not downloads:
791
+ result_details = f"No download status found for model '{request.model_id}'"
792
+ elif request.model_id:
793
+ result_details = f"Found download status for model '{request.model_id}'"
794
+ else:
795
+ result_details = f"Found {len(downloads)} download status records"
796
+
797
+ return ListModelDownloadsResultSuccess(
798
+ downloads=downloads,
799
+ result_details=result_details,
800
+ )
801
+
802
+ except Exception as e:
803
+ error_msg = f"Failed to get download status: {e}"
804
+ return ListModelDownloadsResultFailure(
805
+ result_details=error_msg,
806
+ exception=e,
807
+ )
808
+
809
+ def _list_models(self) -> list[ModelInfo]:
810
+ """Synchronous model listing implementation using HuggingFace Hub SDK.
811
+
812
+ Uses scan_cache_dir to get information about cached models.
813
+
814
+ Returns:
815
+ list[ModelInfo]: List of model information
816
+ """
817
+ try:
818
+ cache_info = scan_cache_dir()
819
+ models = []
820
+
821
+ for repo in cache_info.repos:
822
+ # Calculate total size across all revisions
823
+ total_size = sum(revision.size_on_disk for revision in repo.revisions)
824
+
825
+ model_info = ModelInfo(
826
+ model_id=repo.repo_id,
827
+ local_path=str(repo.repo_path),
828
+ size_bytes=total_size,
829
+ )
830
+ models.append(model_info)
831
+
832
+ except Exception as e:
833
+ logger.warning("Failed to scan cache directory: %s", e)
834
+ return []
835
+ else:
836
+ return models
837
+
838
+ def _delete_model(self, model_id: str) -> str:
839
+ """Synchronous model deletion implementation using HuggingFace Hub SDK.
840
+
841
+ Uses scan_cache_dir to find and delete the model from cache.
842
+
843
+ Args:
844
+ model_id: The model ID to delete
845
+
846
+ Returns:
847
+ str: Information about what was deleted
848
+
849
+ Raises:
850
+ FileNotFoundError: If the model is not found in cache
851
+ """
852
+ cache_info = scan_cache_dir()
853
+
854
+ # Find the repo to delete
855
+ repo_to_delete = None
856
+ for repo in cache_info.repos:
857
+ if repo.repo_id == model_id:
858
+ repo_to_delete = repo
859
+ break
860
+
861
+ if repo_to_delete is None:
862
+ error_msg = f"Model '{model_id}' not found in cache"
863
+ raise FileNotFoundError(error_msg)
864
+
865
+ # Get all revision hashes for this repo
866
+ revision_hashes = [revision.commit_hash for revision in repo_to_delete.revisions]
867
+
868
+ if not revision_hashes:
869
+ error_msg = f"No revisions found for model '{model_id}'"
870
+ raise FileNotFoundError(error_msg)
871
+
872
+ # Create delete strategy for all revisions of this repo
873
+ delete_strategy = cache_info.delete_revisions(*revision_hashes)
874
+
875
+ # Execute the deletion
876
+ delete_strategy.execute()
877
+
878
+ return f"Deleted model '{model_id}' (freed {delete_strategy.expected_freed_size_str})"
879
+
880
+ def _get_model_info(self, model_dir: Path) -> dict[str, str | int | float] | None:
881
+ """Get information about a cached model.
882
+
883
+ Args:
884
+ model_dir: Path to the model directory in cache
885
+
886
+ Returns:
887
+ dict | None: Model information or None if not a valid model directory
888
+ """
889
+ try:
890
+ # Extract model_id from directory name
891
+ # HuggingFace cache format: models--{org}--{model}--{hash}
892
+ dir_name = model_dir.name
893
+ if not dir_name.startswith("models--"):
894
+ return None
895
+
896
+ # Parse the model ID from the directory name
897
+ parts = dir_name.split("--")
898
+ if len(parts) >= MIN_CACHE_DIR_PARTS:
899
+ # Reconstruct model_id as org/model
900
+ model_id = f"{parts[1]}/{parts[2]}"
901
+ else:
902
+ model_id = dir_name[8:] # Remove "models--" prefix
903
+
904
+ # Calculate directory size
905
+ total_size = sum(f.stat().st_size for f in model_dir.rglob("*") if f.is_file())
906
+
907
+ return {
908
+ "model_id": model_id,
909
+ "local_path": str(model_dir),
910
+ "size_bytes": total_size,
911
+ "size_mb": round(total_size / (1024 * 1024), 2),
912
+ }
913
+
914
+ except Exception:
915
+ # If we can't parse the directory, skip it
916
+ return None
917
+
918
+ def _download_model(self, download_params: dict[str, str | list[str] | None]) -> Path:
919
+ """Model download implementation.
920
+
921
+ Args:
922
+ download_params: Dictionary containing download parameters
923
+
924
+ Returns:
925
+ Path: Local path where the model was downloaded
926
+ """
927
+ # Validate parameters and build download kwargs
928
+ download_kwargs = self._build_download_kwargs(download_params)
929
+
930
+ # Set the current model ID for progress tracking
931
+ model_id = download_params["model_id"]
932
+ if isinstance(model_id, str):
933
+ ModelDownloadTracker.set_current_model_id(model_id)
934
+
935
+ # Execute download with progress tracking
936
+ local_path = snapshot_download(**download_kwargs) # type: ignore[arg-type]
937
+
938
+ # Clear the current model ID when done
939
+ ModelDownloadTracker.set_current_model_id("")
940
+
941
+ return Path(local_path)
942
+
943
+ def _build_download_kwargs(self, download_params: dict[str, str | list[str] | None]) -> dict:
944
+ """Build kwargs for snapshot_download with validation.
945
+
946
+ Args:
947
+ download_params: Dictionary containing download parameters
948
+
949
+ Returns:
950
+ dict: Validated download kwargs for snapshot_download
951
+ """
952
+ param_model_id = download_params["model_id"]
953
+ local_dir = download_params["local_dir"]
954
+ revision = download_params["revision"]
955
+ allow_patterns = download_params["allow_patterns"]
956
+ ignore_patterns = download_params["ignore_patterns"]
957
+
958
+ # Build base kwargs with custom progress tracking
959
+ download_kwargs: dict[str, Any] = {
960
+ "repo_id": param_model_id,
961
+ "repo_type": "model",
962
+ "revision": revision,
963
+ "tqdm_class": ModelDownloadTracker,
964
+ }
965
+
966
+ # Add optional parameters
967
+ if local_dir is not None and isinstance(local_dir, str):
968
+ download_kwargs["local_dir"] = local_dir
969
+ if allow_patterns is not None and isinstance(allow_patterns, list):
970
+ download_kwargs["allow_patterns"] = allow_patterns
971
+ if ignore_patterns is not None and isinstance(ignore_patterns, list):
972
+ download_kwargs["ignore_patterns"] = ignore_patterns
973
+
974
+ return download_kwargs
975
+
976
+ def _parse_model_id(self, model_input: str) -> str:
977
+ """Parse model ID from either a direct model ID or a Hugging Face URL.
978
+
979
+ Args:
980
+ model_input: Either a model ID (e.g., 'microsoft/DialoGPT-medium')
981
+ or a Hugging Face URL (e.g., 'https://huggingface.co/microsoft/DialoGPT-medium')
982
+
983
+ Returns:
984
+ str: The parsed model ID in the format 'namespace/repo_name' or 'repo_name'
985
+ """
986
+ # If it's already a simple model ID (no URL scheme), return as-is
987
+ if not model_input.startswith(("http://", "https://")):
988
+ return model_input
989
+
990
+ # Parse the URL
991
+ parsed = urlparse(model_input)
992
+
993
+ # Check if it's a Hugging Face URL
994
+ if parsed.netloc in ("huggingface.co", "www.huggingface.co"):
995
+ # Extract the path and remove leading slash
996
+ path = parsed.path.lstrip("/")
997
+
998
+ # Remove any trailing parameters or fragments
999
+ # The model ID should be in the format: namespace/repo_name or just repo_name
1000
+ model_id_match = re.match(r"^([^/]+/[^/?#]+|[^/?#]+)", path)
1001
+ if model_id_match:
1002
+ return model_id_match.group(1)
1003
+
1004
+ # If we can't parse it, return the original input and let huggingface_hub handle the error
1005
+ return model_input
1006
+
1007
+ def _get_download_statuses(self, model_id: str | None = None) -> list[ModelDownloadStatus]:
1008
+ """Get download statuses for a specific model or all models.
1009
+
1010
+ Args:
1011
+ model_id: Optional model ID to get status for. If None, returns all statuses.
1012
+
1013
+ Returns:
1014
+ list[ModelDownloadStatus]: List of download statuses
1015
+ """
1016
+ if model_id:
1017
+ # Get status for specific model
1018
+ status = self._read_model_download_status(model_id)
1019
+ return [status] if status else []
1020
+ # Get all download statuses
1021
+ return self._list_all_download_statuses()
1022
+
1023
    async def on_handle_delete_model_download_request(self, request: DeleteModelDownloadRequest) -> ResultPayload:
        """Handle model download status deletion requests asynchronously.

        This method removes download tracking records for a specific model
        from the local status files stored in the XDG data directory.
        If the model is currently downloading or failed, it also cancels
        the download task and deletes any cached model files.

        The teardown order matters: the status is read first (it decides
        below whether cached files are removed), then any live download
        process/task is cancelled, then the status file is deleted, and
        only then are non-completed cached files cleaned up.

        Args:
            request: The delete request containing model_id

        Returns:
            ResultPayload: Success result with deletion confirmation or failure with error details
        """
        model_id = request.model_id

        try:
            # Check current download status first — blocking file reads run
            # in a worker thread to keep the event loop responsive.
            download_status = await asyncio.to_thread(self._read_model_download_status, model_id)

            # Cancel active download process if it exists
            if model_id in self._download_processes:
                process = self._download_processes[model_id]
                await cancel_subprocess(process, f"download process for model '{model_id}'")
                del self._download_processes[model_id]

            # Cancel the tracking asyncio task (if still running) and drop it.
            if model_id in self._download_tasks:
                task = self._download_tasks[model_id]
                if not task.done():
                    task.cancel()
                    logger.info("Cancelled active download task for model '%s'", model_id)
                del self._download_tasks[model_id]

            # Delete status file; a missing file raises FileNotFoundError,
            # which the outer handler turns into a not-found failure result.
            deleted_path = await asyncio.to_thread(self._delete_model_download_status, model_id)

            # Only delete cached model if it's not completed — completed
            # downloads keep their cached files.
            if download_status and download_status.status != "completed":
                try:
                    await asyncio.to_thread(self._delete_model, model_id)
                except FileNotFoundError:
                    # No cached files is fine here — swallowed so it is not
                    # confused with a missing status file.
                    logger.debug("No cached model files found for '%s'", model_id)

            result_details = f"Successfully deleted download status for model '{model_id}'"

            return DeleteModelDownloadResultSuccess(
                model_id=model_id,
                deleted_path=deleted_path,
                result_details=result_details,
            )

        except FileNotFoundError:
            error_msg = f"Download status for model '{model_id}' not found"
            return DeleteModelDownloadResultFailure(
                result_details=error_msg,
                exception=FileNotFoundError(error_msg),
            )

        except Exception as e:
            # Boundary handler: convert any other failure into a payload.
            error_msg = f"Failed to delete download status for '{model_id}': {e}"
            return DeleteModelDownloadResultFailure(
                result_details=error_msg,
                exception=e,
            )
1087
+
1088
+ def _delete_model_download_status(self, model_id: str) -> str:
1089
+ """Delete download status file for a specific model.
1090
+
1091
+ Args:
1092
+ model_id: The model ID to remove download status for
1093
+
1094
+ Returns:
1095
+ str: Path to the deleted status file
1096
+
1097
+ Raises:
1098
+ FileNotFoundError: If the status file is not found
1099
+ """
1100
+ status_file = self._get_status_file_path(model_id)
1101
+
1102
+ if not status_file.exists():
1103
+ msg = f"Download status file not found for model '{model_id}'"
1104
+ raise FileNotFoundError(msg)
1105
+
1106
+ status_file.unlink()
1107
+ return str(status_file)