griptape-nodes 0.53.0__py3-none-any.whl → 0.54.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +5 -2
- griptape_nodes/app/app.py +4 -26
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
- griptape_nodes/cli/commands/config.py +4 -1
- griptape_nodes/cli/commands/init.py +5 -3
- griptape_nodes/cli/commands/libraries.py +14 -8
- griptape_nodes/cli/commands/models.py +504 -0
- griptape_nodes/cli/commands/self.py +5 -2
- griptape_nodes/cli/main.py +11 -1
- griptape_nodes/cli/shared.py +0 -9
- griptape_nodes/common/directed_graph.py +17 -1
- griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
- griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
- griptape_nodes/drivers/storage/local_storage_driver.py +17 -13
- griptape_nodes/exe_types/node_types.py +219 -14
- griptape_nodes/exe_types/param_components/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
- griptape_nodes/machines/control_flow.py +129 -92
- griptape_nodes/machines/dag_builder.py +207 -0
- griptape_nodes/machines/parallel_resolution.py +264 -276
- griptape_nodes/machines/sequential_resolution.py +9 -7
- griptape_nodes/node_library/library_registry.py +34 -1
- griptape_nodes/retained_mode/events/app_events.py +5 -1
- griptape_nodes/retained_mode/events/base_events.py +7 -7
- griptape_nodes/retained_mode/events/config_events.py +30 -0
- griptape_nodes/retained_mode/events/execution_events.py +2 -2
- griptape_nodes/retained_mode/events/model_events.py +296 -0
- griptape_nodes/retained_mode/griptape_nodes.py +10 -1
- griptape_nodes/retained_mode/managers/agent_manager.py +14 -0
- griptape_nodes/retained_mode/managers/config_manager.py +44 -3
- griptape_nodes/retained_mode/managers/event_manager.py +8 -2
- griptape_nodes/retained_mode/managers/flow_manager.py +45 -14
- griptape_nodes/retained_mode/managers/library_manager.py +3 -3
- griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
- griptape_nodes/retained_mode/managers/node_manager.py +26 -26
- griptape_nodes/retained_mode/managers/object_manager.py +1 -1
- griptape_nodes/retained_mode/managers/os_manager.py +6 -6
- griptape_nodes/retained_mode/managers/settings.py +87 -9
- griptape_nodes/retained_mode/managers/static_files_manager.py +77 -9
- griptape_nodes/retained_mode/managers/sync_manager.py +10 -5
- griptape_nodes/retained_mode/managers/workflow_manager.py +101 -92
- griptape_nodes/retained_mode/retained_mode.py +19 -0
- griptape_nodes/servers/__init__.py +1 -0
- griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
- griptape_nodes/{app/api.py → servers/static.py} +43 -40
- griptape_nodes/traits/button.py +124 -6
- griptape_nodes/traits/multi_options.py +188 -0
- griptape_nodes/traits/numbers_selector.py +77 -0
- griptape_nodes/traits/options.py +93 -2
- griptape_nodes/utils/async_utils.py +31 -0
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/METADATA +3 -1
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/RECORD +56 -47
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/WHEEL +1 -1
- /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
- {griptape_nodes-0.53.0.dist-info → griptape_nodes-0.54.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,504 @@
|
|
|
1
|
+
"""Models command for managing AI models."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
import typer
|
|
7
|
+
from rich.table import Table
|
|
8
|
+
|
|
9
|
+
from griptape_nodes.cli.shared import console
|
|
10
|
+
from griptape_nodes.retained_mode.events.model_events import (
|
|
11
|
+
DeleteModelDownloadRequest,
|
|
12
|
+
DeleteModelDownloadResultFailure,
|
|
13
|
+
DeleteModelDownloadResultSuccess,
|
|
14
|
+
DeleteModelRequest,
|
|
15
|
+
DeleteModelResultFailure,
|
|
16
|
+
DeleteModelResultSuccess,
|
|
17
|
+
ListModelDownloadsRequest,
|
|
18
|
+
ListModelDownloadsResultFailure,
|
|
19
|
+
ListModelDownloadsResultSuccess,
|
|
20
|
+
ListModelsRequest,
|
|
21
|
+
ListModelsResultFailure,
|
|
22
|
+
ListModelsResultSuccess,
|
|
23
|
+
ModelInfo,
|
|
24
|
+
QueryInfo,
|
|
25
|
+
SearchModelsRequest,
|
|
26
|
+
SearchModelsResultFailure,
|
|
27
|
+
SearchModelsResultSuccess,
|
|
28
|
+
)
|
|
29
|
+
from griptape_nodes.retained_mode.retained_mode import GriptapeNodes
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from griptape_nodes.retained_mode.events.model_events import ModelDownloadStatus
|
|
33
|
+
|
|
34
|
+
# Sub-application holding the commands that manage download tracking records;
# it is mounted under the root app as the "downloads" group.
downloads_app = typer.Typer(help="Manage model download tracking records.")

# Root Typer application for the model management commands.
app = typer.Typer(help="Manage AI models.")
app.add_typer(downloads_app, name="downloads")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@app.command("download")
def download_command(
    model_id: str = typer.Argument(..., help="Model ID or URL (e.g., 'microsoft/DialoGPT-medium')"),
    local_dir: str | None = typer.Option(None, "--local-dir", help="Local directory to download the model to"),
    revision: str = typer.Option("main", "--revision", help="Git revision to download"),
) -> None:
    """Download a model from Hugging Face Hub."""
    # This command is a plain (sync) function; drive the async worker to completion.
    download_coro = _download_model(model_id=model_id, local_dir=local_dir, revision=revision)
    asyncio.run(download_coro)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@app.command("list")
def list_command() -> None:
    """List all downloaded model files in local cache."""
    # Bridge from the synchronous CLI entry point into the async listing routine.
    listing = _list_models()
    asyncio.run(listing)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@app.command("delete")
def delete_command(
    model_id: str = typer.Argument(..., help="Model ID to delete (e.g., 'microsoft/DialoGPT-medium')"),
) -> None:
    """Delete model files from local cache."""
    # Hand off to the async helper, which issues a DeleteModelRequest.
    deletion = _delete_model(model_id=model_id)
    asyncio.run(deletion)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@downloads_app.command("status")
def downloads_status_command(
    model_id: str | None = typer.Argument(None, help="Optional model ID to check download status for"),
) -> None:
    """Show download status for a specific model or all models."""
    # model_id is optional: None asks for the status of every tracked download,
    # so the annotation must admit None (matches search_command's optional query).
    asyncio.run(_get_model_status(model_id))
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@downloads_app.command("list")
def downloads_list_command() -> None:
    """List all model download status records."""
    # Same query as `downloads status` with no model filter.
    asyncio.run(_get_model_status(model_id=None))
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@downloads_app.command("delete")
def downloads_delete_command(
    model_id: str = typer.Argument(
        ..., help="Model ID to delete download status for (e.g., 'microsoft/DialoGPT-medium')"
    ),
) -> None:
    """Delete download status tracking records for a model."""
    # Hand off to the async helper, which issues a DeleteModelDownloadRequest.
    removal = _delete_model_status(model_id=model_id)
    asyncio.run(removal)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@app.command("search")
def search_command(
    query: str | None = typer.Argument(None, help="Search query to match against model names"),
    task: str | None = typer.Option(None, "--task", help="Filter by task type"),
    limit: int = typer.Option(20, "--limit", help="Maximum number of results (max: 100)"),
    sort: str = typer.Option("downloads", "--sort", help="Sort results by"),
    direction: str = typer.Option("desc", "--direction", help="Sort direction"),
) -> None:
    """Search for models on Hugging Face Hub."""
    # All CLI options are forwarded unchanged. The "max: 100" cap on --limit is not
    # checked here; presumably the request handler enforces it — TODO confirm.
    search = _search_models(query=query, task=task, limit=limit, sort=sort, direction=direction)
    asyncio.run(search)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
async def _download_model(
    model_id: str,
    local_dir: str | None,
    revision: str,
) -> None:
    """Download a model from Hugging Face Hub.

    Calls ``ModelManager.download_model`` directly (not via a request, see the
    comment below) and prints the local path on success.

    Args:
        model_id: Model ID or URL to download
        local_dir: Local directory to download the model to
        revision: Git revision to download

    Raises:
        typer.Exit: With exit code 1 when the download raises, so the failure is
            visible in the process exit status instead of the command exiting 0.
    """
    console.print(f"[bold green]Downloading model: {model_id}[/bold green]")

    try:
        # ModelManager DownloadModelRequest will use this command so it's important that we don't use the request ourselves
        model_manager = GriptapeNodes.ModelManager()
        # download_model is invoked without await, i.e. as a synchronous call.
        local_path = model_manager.download_model(
            model_id=model_id,
            local_dir=local_dir,
            revision=revision,
            allow_patterns=None,
            ignore_patterns=None,
        )
    except Exception as e:
        # Top-level CLI boundary: report the error, then exit non-zero.
        # NOTE(review): if ModelManager runs this command as a child process (as the
        # comment above suggests), the exit status is how it learns of the failure.
        console.print("[bold red]Model download failed:[/bold red]")
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(code=1) from e

    console.print("[bold green]Model downloaded successfully![/bold green]")
    console.print(f"[green]Downloaded to: {local_path}[/green]")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
async def _list_models() -> None:
    """List all downloaded models in the local cache.

    Dispatches a ``ListModelsRequest`` and renders the returned models as a
    table of model ID and size.

    Raises:
        typer.Exit: With exit code 1 when the request fails or raises, so the
            failure is reflected in the process exit status.
    """
    console.print("[bold green]Listing cached models...[/bold green]")

    request = ListModelsRequest()

    try:
        result = await GriptapeNodes.ahandle_request(request)

        if isinstance(result, ListModelsResultSuccess):
            models = result.models
            if models:
                console.print(f"[bold green]Found {len(models)} cached models:[/bold green]")

                table = Table()
                table.add_column("Model ID", style="green")
                table.add_column("Size (GB)", style="cyan", justify="right")

                for model in models:
                    # A missing/zero size renders as 0.0; otherwise bytes / 1024**3, 2 decimals.
                    size_gb = round(model.size_bytes / (1024**3), 2) if model.size_bytes else 0.0
                    table.add_row(model.model_id, str(size_gb))
                console.print(table)
            else:
                console.print("[yellow]No models found in local cache[/yellow]")
            return

        if isinstance(result, ListModelsResultFailure):
            console.print("[bold red]Model listing failed:[/bold red]")
            if result.result_details:
                console.print(f"[red]{result.result_details}[/red]")
            if result.exception:
                console.print(f"[dim]Error: {result.exception}[/dim]")
        else:
            console.print("[bold red]Model listing failed: Unknown error occurred[/bold red]")

    except Exception as e:
        console.print("[bold red]Unexpected error during model listing:[/bold red]")
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(code=1) from e

    # Only the failure branches reach this point; success returns early above.
    raise typer.Exit(code=1)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
async def _delete_model(model_id: str) -> None:
    """Delete a model from the local cache.

    Dispatches a ``DeleteModelRequest`` and reports the outcome on the console.

    Args:
        model_id: Model ID to delete

    Raises:
        typer.Exit: With exit code 1 when the request fails or raises, so the
            failure is reflected in the process exit status.
    """
    console.print(f"[bold yellow]Deleting model: {model_id}[/bold yellow]")

    request = DeleteModelRequest(model_id=model_id)

    try:
        result = await GriptapeNodes.ahandle_request(request)

        if isinstance(result, DeleteModelResultSuccess):
            console.print("[bold green]Model deleted successfully![/bold green]")
            console.print(f"[green]Deleted: {result.deleted_path}[/green]")
            return

        if isinstance(result, DeleteModelResultFailure):
            console.print("[bold red]Model deletion failed:[/bold red]")
            if result.result_details:
                console.print(f"[red]{result.result_details}[/red]")
            if result.exception:
                console.print(f"[dim]Error: {result.exception}[/dim]")
        else:
            console.print("[bold red]Model deletion failed: Unknown error occurred[/bold red]")

    except Exception as e:
        console.print("[bold red]Unexpected error during model deletion:[/bold red]")
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(code=1) from e

    # Only the failure branches reach this point; success returns early above.
    raise typer.Exit(code=1)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _format_download_row(download: "ModelDownloadStatus") -> tuple[str, str, str, str, str, str]:
    """Build the six display cells for one download record.

    Each cell is produced by the matching ``_format_*`` helper; these helpers only
    read attributes of ``download``.

    Args:
        download: ModelDownloadStatus object

    Returns:
        tuple: (model_id, status_colored, progress_str, size_str, eta_str, started_str)
    """
    return (
        download.model_id,
        _format_status(download),
        _format_progress(download),
        _format_size(download),
        _format_eta(download),
        _format_timestamp(download),
    )
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _format_progress(download: "ModelDownloadStatus") -> str:
|
|
239
|
+
"""Format download progress information."""
|
|
240
|
+
if download.total_files is not None and download.completed_files is not None and download.total_files > 0:
|
|
241
|
+
progress_percent = (download.completed_files / download.total_files) * 100
|
|
242
|
+
return f"{progress_percent:.1f}%"
|
|
243
|
+
return "Unknown"
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _format_size(download: "ModelDownloadStatus") -> str:
|
|
247
|
+
"""Format download size information."""
|
|
248
|
+
if download.total_files is not None and download.completed_files is not None:
|
|
249
|
+
return f"{download.completed_files}/{download.total_files} files"
|
|
250
|
+
return "Unknown"
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _format_eta(download: "ModelDownloadStatus") -> str:
|
|
254
|
+
"""Format estimated time of arrival."""
|
|
255
|
+
# ETA is not available in the current ModelDownloadStatus structure
|
|
256
|
+
# For active downloads, we could potentially calculate based on progress
|
|
257
|
+
# but without timing data, we return a status-appropriate message
|
|
258
|
+
if download.status == "downloading":
|
|
259
|
+
return "In progress"
|
|
260
|
+
if download.status == "completed":
|
|
261
|
+
return "Completed"
|
|
262
|
+
if download.status == "failed":
|
|
263
|
+
return "Failed"
|
|
264
|
+
return "Unknown"
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _format_timestamp(download: "ModelDownloadStatus") -> str:
|
|
268
|
+
"""Format started timestamp."""
|
|
269
|
+
started_at = download.started_at
|
|
270
|
+
if not started_at:
|
|
271
|
+
return "Unknown"
|
|
272
|
+
|
|
273
|
+
try:
|
|
274
|
+
from datetime import datetime
|
|
275
|
+
|
|
276
|
+
dt = datetime.fromisoformat(started_at)
|
|
277
|
+
return dt.strftime("%H:%M:%S")
|
|
278
|
+
except Exception:
|
|
279
|
+
return started_at[:10] # Fallback
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _format_status(download: "ModelDownloadStatus") -> str:
|
|
283
|
+
"""Format status with color coding."""
|
|
284
|
+
status = download.status
|
|
285
|
+
status_colors = {
|
|
286
|
+
"completed": "green",
|
|
287
|
+
"failed": "red",
|
|
288
|
+
"downloading": "yellow",
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
if status in status_colors:
|
|
292
|
+
return f"[{status_colors[status]}]{status}[/{status_colors[status]}]"
|
|
293
|
+
return status
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _display_downloads_table(downloads: list["ModelDownloadStatus"]) -> None:
    """Print a summary line followed by a Rich table with one row per download.

    Args:
        downloads: List of ModelDownloadStatus objects
    """
    console.print(f"[bold green]Found {len(downloads)} download(s):[/bold green]")

    table = Table()
    # (header, add_column keyword options) in display order.
    column_specs = (
        ("Model ID", {"style": "green"}),
        ("Status", {"style": "cyan"}),
        ("Progress", {"style": "yellow", "justify": "right"}),
        ("Size", {"style": "blue", "justify": "right"}),
        ("ETA", {"style": "magenta", "justify": "right"}),
        ("Started", {"style": "dim"}),
    )
    for header, options in column_specs:
        table.add_column(header, **options)

    for record in downloads:
        table.add_row(*_format_download_row(record))

    console.print(table)
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
async def _get_model_status(model_id: str | None) -> None:
    """Get download status for models.

    Dispatches a ``ListModelDownloadsRequest`` and renders any returned download
    records as a table. Finding no records is not treated as an error.

    Args:
        model_id: Optional model ID to get status for; None requests all records

    Raises:
        typer.Exit: With exit code 1 when the request fails or raises, so the
            failure is reflected in the process exit status.
    """
    if model_id:
        console.print(f"[bold green]Getting download status for: {model_id}[/bold green]")
    else:
        console.print("[bold green]Getting download status for all models...[/bold green]")

    request = ListModelDownloadsRequest(model_id=model_id)

    try:
        result = await GriptapeNodes.ahandle_request(request)

        if isinstance(result, ListModelDownloadsResultSuccess):
            downloads = result.downloads
            if downloads:
                _display_downloads_table(downloads)
            elif model_id:
                console.print(f"[yellow]No download found for model: {model_id}[/yellow]")
            else:
                console.print("[yellow]No downloads found[/yellow]")
            return

        if isinstance(result, ListModelDownloadsResultFailure):
            console.print("[bold red]Failed to get download status:[/bold red]")
            if result.result_details:
                console.print(f"[red]{result.result_details}[/red]")
            if result.exception:
                console.print(f"[dim]Error: {result.exception}[/dim]")
        else:
            console.print("[bold red]Failed to get download status: Unknown error occurred[/bold red]")

    except Exception as e:
        console.print("[bold red]Unexpected error getting download status:[/bold red]")
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(code=1) from e

    # Only the failure branches reach this point; success returns early above.
    raise typer.Exit(code=1)
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
async def _search_models(
    query: str | None,
    task: str | None,
    limit: int,
    sort: str,
    direction: str,
) -> None:
    """Search for models on Hugging Face Hub.

    Dispatches a ``SearchModelsRequest``. The library, author and tags filters are
    always sent as None from the CLI. Matches are shown as a table; an empty
    result is not treated as an error.

    Args:
        query: Search query to match against model names
        task: Filter by task type
        limit: Maximum number of results
        sort: Sort results by
        direction: Sort direction

    Raises:
        typer.Exit: With exit code 1 when the request fails or raises, so the
            failure is reflected in the process exit status.
    """
    if query:
        console.print(f"[bold green]Searching for models: {query}[/bold green]")
    else:
        console.print("[bold green]Searching for models...[/bold green]")

    request = SearchModelsRequest(
        query=query,
        task=task,
        library=None,
        author=None,
        tags=None,
        limit=limit,
        sort=sort,
        direction=direction,
    )

    try:
        result = await GriptapeNodes.ahandle_request(request)

        if isinstance(result, SearchModelsResultSuccess):
            models = result.models
            if models:
                _display_search_results(models, result.query_info)
            else:
                console.print("[yellow]No models found matching the search criteria[/yellow]")
            return

        if isinstance(result, SearchModelsResultFailure):
            console.print("[bold red]Model search failed:[/bold red]")
            if result.result_details:
                console.print(f"[red]{result.result_details}[/red]")
            if result.exception:
                console.print(f"[dim]Error: {result.exception}[/dim]")
        else:
            console.print("[bold red]Model search failed: Unknown error occurred[/bold red]")

    except Exception as e:
        console.print("[bold red]Unexpected error during model search:[/bold red]")
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(code=1) from e

    # Only the failure branches reach this point; success returns early above.
    raise typer.Exit(code=1)
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def _display_search_results(models: list[ModelInfo], query_info: QueryInfo) -> None:
    """Print a result count, the non-empty search filters, and a table of matches.

    Args:
        models: List of model information
        query_info: Information about the search query
    """
    console.print(f"[bold green]Found {len(models)} models[/bold green]")

    # Echo back only the filters that were actually set, in a fixed order.
    params = [
        f"{label}: {value}"
        for label, value in (
            ("query", query_info.query),
            ("task", query_info.task),
            ("library", query_info.library),
            ("author", query_info.author),
        )
        if value
    ]
    if query_info.tags:
        params.append(f"tags: {', '.join(query_info.tags)}")
    if params:
        console.print(f"[dim]Search parameters: {', '.join(params)}[/dim]")

    table = Table()
    for header, options in (
        ("Model ID", {"style": "green"}),
        ("Author", {"style": "blue"}),
        ("Downloads", {"style": "cyan", "justify": "right"}),
        ("Likes", {"style": "yellow", "justify": "right"}),
        ("Task", {"style": "magenta"}),
        ("Library", {"style": "dim"}),
    ):
        table.add_column(header, **options)

    for info in models:
        # Missing optional fields render as empty strings; missing counts as "0".
        table.add_row(
            info.model_id,
            info.author or "",
            f"{info.downloads:,}" if info.downloads else "0",
            str(info.likes or 0),
            info.task or "",
            info.library or "",
        )

    console.print(table)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
async def _delete_model_status(model_id: str) -> None:
    """Delete download status records for a model.

    Dispatches a ``DeleteModelDownloadRequest`` and reports the outcome on the
    console.

    Args:
        model_id: Model ID to delete download status for

    Raises:
        typer.Exit: With exit code 1 when the request fails or raises, so the
            failure is reflected in the process exit status.
    """
    console.print(f"[bold yellow]Deleting download status for: {model_id}[/bold yellow]")

    request = DeleteModelDownloadRequest(model_id=model_id)

    try:
        result = await GriptapeNodes.ahandle_request(request)

        if isinstance(result, DeleteModelDownloadResultSuccess):
            console.print("[bold green]Download status deleted successfully![/bold green]")
            console.print(f"[green]Deleted status file: {result.deleted_path}[/green]")
            return

        if isinstance(result, DeleteModelDownloadResultFailure):
            console.print("[bold red]Download status deletion failed:[/bold red]")
            if result.result_details:
                console.print(f"[red]{result.result_details}[/red]")
            if result.exception:
                console.print(f"[dim]Error: {result.exception}[/dim]")
        else:
            console.print("[bold red]Download status deletion failed: Unknown error occurred[/bold red]")

    except Exception as e:
        console.print("[bold red]Unexpected error during download status deletion:[/bold red]")
        console.print(f"[red]{e}[/red]")
        raise typer.Exit(code=1) from e

    # Only the failure branches reach this point; success returns early above.
    raise typer.Exit(code=1)
|
|
@@ -11,10 +11,9 @@ from griptape_nodes.cli.shared import (
|
|
|
11
11
|
GITHUB_UPDATE_URL,
|
|
12
12
|
LATEST_TAG,
|
|
13
13
|
PYPI_UPDATE_URL,
|
|
14
|
-
config_manager,
|
|
15
14
|
console,
|
|
16
|
-
os_manager,
|
|
17
15
|
)
|
|
16
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
18
17
|
from griptape_nodes.utils.uv_utils import find_uv_bin
|
|
19
18
|
from griptape_nodes.utils.version_utils import (
|
|
20
19
|
get_complete_version_string,
|
|
@@ -23,6 +22,10 @@ from griptape_nodes.utils.version_utils import (
|
|
|
23
22
|
get_latest_version_pypi,
|
|
24
23
|
)
|
|
25
24
|
|
|
25
|
+
# Manager handles obtained through the GriptapeNodes facade rather than constructed
# locally; evaluated left to right, in the same order as separate assignments.
config_manager, secrets_manager, os_manager = (
    GriptapeNodes.ConfigManager(),
    GriptapeNodes.SecretsManager(),
    GriptapeNodes.OSManager(),
)
|
|
28
|
+
|
|
26
29
|
app = typer.Typer(help="Manage this CLI installation.")
|
|
27
30
|
|
|
28
31
|
|
griptape_nodes/cli/main.py
CHANGED
|
@@ -9,7 +9,8 @@ from rich.console import Console
|
|
|
9
9
|
# Add current directory to path for imports to work
|
|
10
10
|
sys.path.append(str(Path.cwd()))
|
|
11
11
|
|
|
12
|
-
from griptape_nodes.cli.commands import config, engine, init, libraries, self
|
|
12
|
+
from griptape_nodes.cli.commands import config, engine, init, libraries, models, self
|
|
13
|
+
from griptape_nodes.utils.version_utils import get_complete_version_string
|
|
13
14
|
|
|
14
15
|
console = Console()
|
|
15
16
|
|
|
@@ -26,17 +27,26 @@ app.command("init", help="Initialize engine configuration.")(init.init_command)
|
|
|
26
27
|
app.add_typer(config.app, name="config")
|
|
27
28
|
app.add_typer(self.app, name="self")
|
|
28
29
|
app.add_typer(libraries.app, name="libraries")
|
|
30
|
+
app.add_typer(models.app, name="models")
|
|
29
31
|
app.command("engine")(engine.engine_command)
|
|
30
32
|
|
|
31
33
|
|
|
32
34
|
@app.callback()
|
|
33
35
|
def main(
|
|
34
36
|
ctx: typer.Context,
|
|
37
|
+
version: bool = typer.Option( # noqa: FBT001
|
|
38
|
+
False, "--version", help="Show version and exit."
|
|
39
|
+
),
|
|
35
40
|
no_update: bool = typer.Option( # noqa: FBT001
|
|
36
41
|
False, "--no-update", help="Skip the auto-update check."
|
|
37
42
|
),
|
|
38
43
|
) -> None:
|
|
39
44
|
"""Griptape Nodes Engine CLI."""
|
|
45
|
+
if version:
|
|
46
|
+
version_string = get_complete_version_string()
|
|
47
|
+
console.print(f"[bold green]{version_string}[/bold green]")
|
|
48
|
+
raise typer.Exit
|
|
49
|
+
|
|
40
50
|
if ctx.invoked_subcommand is None:
|
|
41
51
|
# Default to engine command when no subcommand is specified
|
|
42
52
|
engine.engine_command(no_update=no_update)
|
griptape_nodes/cli/shared.py
CHANGED
|
@@ -8,10 +8,6 @@ from typing import Any
|
|
|
8
8
|
from rich.console import Console
|
|
9
9
|
from xdg_base_dirs import xdg_config_home, xdg_data_home
|
|
10
10
|
|
|
11
|
-
from griptape_nodes.retained_mode.managers.config_manager import ConfigManager
|
|
12
|
-
from griptape_nodes.retained_mode.managers.os_manager import OSManager
|
|
13
|
-
from griptape_nodes.retained_mode.managers.secrets_manager import SecretsManager
|
|
14
|
-
|
|
15
11
|
|
|
16
12
|
@dataclass
|
|
17
13
|
class InitConfig:
|
|
@@ -61,11 +57,6 @@ ENV_LIBRARIES_SYNC = (
|
|
|
61
57
|
ENV_GTN_BUCKET_NAME = os.getenv("GTN_BUCKET_NAME")
|
|
62
58
|
ENV_LIBRARIES_BASE_DIR = os.getenv("GTN_LIBRARIES_BASE_DIR", str(DATA_DIR / "libraries"))
|
|
63
59
|
|
|
64
|
-
# Initialize managers
|
|
65
|
-
config_manager = ConfigManager()
|
|
66
|
-
secrets_manager = SecretsManager(config_manager)
|
|
67
|
-
os_manager = OSManager()
|
|
68
|
-
|
|
69
60
|
|
|
70
61
|
def init_system_config() -> None:
|
|
71
62
|
"""Initializes the system config directory if it doesn't exist."""
|
|
@@ -12,6 +12,10 @@ class DirectedGraph:
|
|
|
12
12
|
self._nodes: set[str] = set()
|
|
13
13
|
self._predecessors: dict[str, set[str]] = {}
|
|
14
14
|
|
|
15
|
+
def __len__(self) -> int:
|
|
16
|
+
"""Return the number of nodes in the graph."""
|
|
17
|
+
return len(self._nodes)
|
|
18
|
+
|
|
15
19
|
def add_node(self, node_for_adding: str) -> None:
|
|
16
20
|
"""Add a node to the graph."""
|
|
17
21
|
self._nodes.add(node_for_adding)
|
|
@@ -31,9 +35,21 @@ class DirectedGraph:
|
|
|
31
35
|
def in_degree(self, node: str) -> int:
    """Count the incoming edges of ``node``, i.e. the size of its predecessor set.

    Raises:
        KeyError: If ``node`` is not part of the graph.
    """
    if node in self._nodes:
        # A node with no recorded predecessors has in-degree 0.
        return len(self._predecessors.get(node, set()))
    msg = f"Node {node} not found in graph"
    raise KeyError(msg)
|
|
36
41
|
|
|
42
|
+
def out_degree(self, node: str) -> int:
    """Count the outgoing edges of ``node``.

    An outgoing edge is counted for every node whose predecessor set contains
    ``node``.

    Raises:
        KeyError: If ``node`` is not part of the graph.
    """
    if node not in self._nodes:
        msg = f"Node {node} not found in graph"
        raise KeyError(msg)
    # Only predecessor sets are consulted, so every set is scanned (linear in their number).
    return sum(1 for preds in self._predecessors.values() if node in preds)
|
|
52
|
+
|
|
37
53
|
def remove_node(self, node: str) -> None:
|
|
38
54
|
"""Remove a node and all its edges from the graph."""
|
|
39
55
|
if node not in self._nodes:
|