griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- griptape_nodes/__init__.py +8 -942
- griptape_nodes/__main__.py +6 -0
- griptape_nodes/app/app.py +48 -86
- griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
- griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
- griptape_nodes/cli/__init__.py +1 -0
- griptape_nodes/cli/commands/__init__.py +1 -0
- griptape_nodes/cli/commands/config.py +74 -0
- griptape_nodes/cli/commands/engine.py +80 -0
- griptape_nodes/cli/commands/init.py +550 -0
- griptape_nodes/cli/commands/libraries.py +96 -0
- griptape_nodes/cli/commands/models.py +504 -0
- griptape_nodes/cli/commands/self.py +120 -0
- griptape_nodes/cli/main.py +56 -0
- griptape_nodes/cli/shared.py +75 -0
- griptape_nodes/common/__init__.py +1 -0
- griptape_nodes/common/directed_graph.py +71 -0
- griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
- griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
- griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
- griptape_nodes/exe_types/core_types.py +60 -2
- griptape_nodes/exe_types/node_types.py +257 -38
- griptape_nodes/exe_types/param_components/__init__.py +1 -0
- griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
- griptape_nodes/machines/control_flow.py +195 -94
- griptape_nodes/machines/dag_builder.py +207 -0
- griptape_nodes/machines/fsm.py +10 -1
- griptape_nodes/machines/parallel_resolution.py +558 -0
- griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
- griptape_nodes/node_library/library_registry.py +34 -1
- griptape_nodes/retained_mode/events/app_events.py +5 -1
- griptape_nodes/retained_mode/events/base_events.py +9 -9
- griptape_nodes/retained_mode/events/config_events.py +30 -0
- griptape_nodes/retained_mode/events/execution_events.py +2 -2
- griptape_nodes/retained_mode/events/model_events.py +296 -0
- griptape_nodes/retained_mode/events/node_events.py +4 -3
- griptape_nodes/retained_mode/griptape_nodes.py +34 -12
- griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
- griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
- griptape_nodes/retained_mode/managers/config_manager.py +44 -3
- griptape_nodes/retained_mode/managers/context_manager.py +6 -5
- griptape_nodes/retained_mode/managers/event_manager.py +8 -2
- griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
- griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
- griptape_nodes/retained_mode/managers/library_manager.py +35 -25
- griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
- griptape_nodes/retained_mode/managers/node_manager.py +102 -220
- griptape_nodes/retained_mode/managers/object_manager.py +11 -5
- griptape_nodes/retained_mode/managers/os_manager.py +28 -13
- griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
- griptape_nodes/retained_mode/managers/settings.py +116 -7
- griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
- griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
- griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
- griptape_nodes/retained_mode/retained_mode.py +19 -0
- griptape_nodes/servers/__init__.py +1 -0
- griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
- griptape_nodes/{app/api.py → servers/static.py} +43 -40
- griptape_nodes/traits/add_param_button.py +1 -1
- griptape_nodes/traits/button.py +334 -6
- griptape_nodes/traits/color_picker.py +66 -0
- griptape_nodes/traits/multi_options.py +188 -0
- griptape_nodes/traits/numbers_selector.py +77 -0
- griptape_nodes/traits/options.py +93 -2
- griptape_nodes/traits/traits.json +4 -0
- griptape_nodes/utils/async_utils.py +31 -0
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
- /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
- {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
|
@@ -10,6 +10,7 @@ from griptape_nodes.retained_mode.events.base_events import (
|
|
|
10
10
|
)
|
|
11
11
|
from griptape_nodes.retained_mode.events.config_events import (
|
|
12
12
|
GetConfigCategoryRequest,
|
|
13
|
+
GetConfigSchemaRequest,
|
|
13
14
|
GetConfigValueRequest,
|
|
14
15
|
SetConfigCategoryRequest,
|
|
15
16
|
SetConfigValueRequest,
|
|
@@ -1396,6 +1397,24 @@ class RetainedMode:
|
|
|
1396
1397
|
result = GriptapeNodes().handle_request(request)
|
|
1397
1398
|
return result
|
|
1398
1399
|
|
|
1400
|
+
@classmethod
def get_config_schema(cls) -> ResultPayload:
    """Fetch the JSON schema that describes the configuration model.

    Builds a ``GetConfigSchemaRequest`` and hands it straight to
    ``GriptapeNodes().handle_request``; whatever payload the engine produces is
    returned unchanged.

    Returns:
        ResultPayload: Contains the configuration schema with field types, enums, and validation rules.

    Example:
        # Get the configuration schema
        schema_result = cmd.get_config_schema()
        if isinstance(schema_result, GetConfigSchemaResultSuccess):
            schema = schema_result.schema
            # Use schema to render appropriate UI components
    """
    return GriptapeNodes().handle_request(GetConfigSchemaRequest())
|
|
1417
|
+
|
|
1399
1418
|
@classmethod
|
|
1400
1419
|
def rename(cls, object_name: str, requested_name: str) -> ResultPayload:
|
|
1401
1420
|
"""Renames a node or flow.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Package for web servers the engine may need to start."""
|
|
@@ -16,7 +16,6 @@ from pydantic import TypeAdapter
|
|
|
16
16
|
from rich.logging import RichHandler
|
|
17
17
|
from starlette.types import Receive, Scope, Send
|
|
18
18
|
|
|
19
|
-
from griptape_nodes.mcp_server.ws_request_manager import AsyncRequestManager, WebSocketConnectionManager
|
|
20
19
|
from griptape_nodes.retained_mode.events.base_events import RequestPayload
|
|
21
20
|
from griptape_nodes.retained_mode.events.connection_events import (
|
|
22
21
|
CreateConnectionRequest,
|
|
@@ -38,6 +37,7 @@ from griptape_nodes.retained_mode.events.parameter_events import (
|
|
|
38
37
|
)
|
|
39
38
|
from griptape_nodes.retained_mode.managers.config_manager import ConfigManager
|
|
40
39
|
from griptape_nodes.retained_mode.managers.secrets_manager import SecretsManager
|
|
40
|
+
from griptape_nodes.servers.ws_request_manager import AsyncRequestManager, WebSocketConnectionManager
|
|
41
41
|
|
|
42
42
|
SUPPORTED_REQUEST_EVENTS: dict[str, type[RequestPayload]] = {
|
|
43
43
|
# Nodes
|
|
@@ -4,15 +4,16 @@ import binascii
|
|
|
4
4
|
import logging
|
|
5
5
|
import os
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Annotated
|
|
8
7
|
from urllib.parse import urljoin
|
|
9
8
|
|
|
10
9
|
import uvicorn
|
|
11
|
-
from fastapi import
|
|
10
|
+
from fastapi import FastAPI, HTTPException, Request
|
|
12
11
|
from fastapi.middleware.cors import CORSMiddleware
|
|
13
12
|
from fastapi.staticfiles import StaticFiles
|
|
14
13
|
from rich.logging import RichHandler
|
|
15
14
|
|
|
15
|
+
from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
|
|
16
|
+
|
|
16
17
|
# Whether to enable the static server
|
|
17
18
|
STATIC_SERVER_ENABLED = os.getenv("STATIC_SERVER_ENABLED", "true").lower() == "true"
|
|
18
19
|
# Host of the static server
|
|
@@ -20,7 +21,7 @@ STATIC_SERVER_HOST = os.getenv("STATIC_SERVER_HOST", "localhost")
|
|
|
20
21
|
# Port of the static server
|
|
21
22
|
STATIC_SERVER_PORT = int(os.getenv("STATIC_SERVER_PORT", "8124"))
|
|
22
23
|
# URL path for the static server
|
|
23
|
-
STATIC_SERVER_URL = os.getenv("STATIC_SERVER_URL", "/
|
|
24
|
+
STATIC_SERVER_URL = os.getenv("STATIC_SERVER_URL", "/workspace")
|
|
24
25
|
# Log level for the static server
|
|
25
26
|
STATIC_SERVER_LOG_LEVEL = os.getenv("STATIC_SERVER_LOG_LEVEL", "ERROR").lower()
|
|
26
27
|
|
|
@@ -28,18 +29,6 @@ logger = logging.getLogger("griptape_nodes_api")
|
|
|
28
29
|
logging.getLogger("uvicorn").addHandler(RichHandler(show_time=True, show_path=False, markup=True, rich_tracebacks=True))
|
|
29
30
|
|
|
30
31
|
|
|
31
|
-
# Global static directory - initialized as None and set when starting the API
|
|
32
|
-
static_dir: Path | None = None
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def get_static_dir() -> Path:
|
|
36
|
-
"""FastAPI dependency to get the static directory."""
|
|
37
|
-
if static_dir is None:
|
|
38
|
-
msg = "Static directory is not initialized"
|
|
39
|
-
raise HTTPException(status_code=500, detail=msg)
|
|
40
|
-
return static_dir
|
|
41
|
-
|
|
42
|
-
|
|
43
32
|
"""Create and configure the FastAPI application."""
|
|
44
33
|
app = FastAPI()
|
|
45
34
|
|
|
@@ -52,35 +41,34 @@ async def _create_static_file_upload_url(request: Request) -> dict:
|
|
|
52
41
|
"""
|
|
53
42
|
base_url = request.base_url
|
|
54
43
|
body = await request.json()
|
|
55
|
-
|
|
56
|
-
url = urljoin(str(base_url), f"/static-uploads/{
|
|
44
|
+
file_path = body["file_path"].lstrip("/")
|
|
45
|
+
url = urljoin(str(base_url), f"/static-uploads/{file_path}")
|
|
57
46
|
|
|
58
47
|
return {"url": url}
|
|
59
48
|
|
|
60
49
|
|
|
61
50
|
@app.put("/static-uploads/{file_path:path}")
|
|
62
|
-
async def _create_static_file(
|
|
63
|
-
request: Request, file_path: str, static_directory: Annotated[Path, Depends(get_static_dir)]
|
|
64
|
-
) -> dict:
|
|
51
|
+
async def _create_static_file(request: Request, file_path: str) -> dict:
|
|
65
52
|
"""Upload a static file to the static server."""
|
|
66
53
|
if not STATIC_SERVER_ENABLED:
|
|
67
54
|
msg = "Static server is not enabled. Please set STATIC_SERVER_ENABLED to True."
|
|
68
55
|
raise ValueError(msg)
|
|
69
56
|
|
|
70
|
-
|
|
57
|
+
workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
|
|
58
|
+
full_file_path = workspace_directory / file_path
|
|
71
59
|
|
|
72
60
|
# Create parent directories if they don't exist
|
|
73
|
-
|
|
61
|
+
full_file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
74
62
|
|
|
75
63
|
data = await request.body()
|
|
76
64
|
try:
|
|
77
|
-
|
|
65
|
+
full_file_path.write_bytes(data)
|
|
78
66
|
except binascii.Error as e:
|
|
79
67
|
msg = f"Invalid base64 encoding for file {file_path}."
|
|
80
68
|
logger.error(msg)
|
|
81
69
|
raise HTTPException(status_code=400, detail=msg) from e
|
|
82
70
|
except (OSError, PermissionError) as e:
|
|
83
|
-
msg = f"Failed to write file {
|
|
71
|
+
msg = f"Failed to write file {full_file_path}: {e}"
|
|
84
72
|
logger.error(msg)
|
|
85
73
|
raise HTTPException(status_code=500, detail=msg) from e
|
|
86
74
|
|
|
@@ -88,19 +76,28 @@ async def _create_static_file(
|
|
|
88
76
|
return {"url": static_url}
|
|
89
77
|
|
|
90
78
|
|
|
79
|
+
@app.get("/static-uploads/{file_path_prefix:path}")
|
|
91
80
|
@app.get("/static-uploads/")
|
|
92
|
-
async def _list_static_files(
|
|
93
|
-
"""List
|
|
81
|
+
async def _list_static_files(file_path_prefix: str = "") -> dict:
|
|
82
|
+
"""List static files in the static server under the specified path prefix."""
|
|
94
83
|
if not STATIC_SERVER_ENABLED:
|
|
95
84
|
msg = "Static server is not enabled. Please set STATIC_SERVER_ENABLED to True."
|
|
96
85
|
raise HTTPException(status_code=500, detail=msg)
|
|
97
86
|
|
|
87
|
+
workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
|
|
88
|
+
|
|
89
|
+
# Handle the prefix path
|
|
90
|
+
if file_path_prefix:
|
|
91
|
+
target_directory = workspace_directory / file_path_prefix
|
|
92
|
+
else:
|
|
93
|
+
target_directory = workspace_directory
|
|
94
|
+
|
|
98
95
|
try:
|
|
99
96
|
file_names = []
|
|
100
|
-
if
|
|
101
|
-
for file_path in
|
|
97
|
+
if target_directory.exists() and target_directory.is_dir():
|
|
98
|
+
for file_path in target_directory.rglob("*"):
|
|
102
99
|
if file_path.is_file():
|
|
103
|
-
relative_path = file_path.relative_to(
|
|
100
|
+
relative_path = file_path.relative_to(workspace_directory)
|
|
104
101
|
file_names.append(str(relative_path))
|
|
105
102
|
except (OSError, PermissionError) as e:
|
|
106
103
|
msg = f"Failed to list files in static directory: {e}"
|
|
@@ -111,13 +108,14 @@ async def _list_static_files(static_directory: Annotated[Path, Depends(get_stati
|
|
|
111
108
|
|
|
112
109
|
|
|
113
110
|
@app.delete("/static-files/{file_path:path}")
|
|
114
|
-
async def _delete_static_file(file_path: str
|
|
111
|
+
async def _delete_static_file(file_path: str) -> dict:
|
|
115
112
|
"""Delete a static file from the static server."""
|
|
116
113
|
if not STATIC_SERVER_ENABLED:
|
|
117
114
|
msg = "Static server is not enabled. Please set STATIC_SERVER_ENABLED to True."
|
|
118
115
|
raise HTTPException(status_code=500, detail=msg)
|
|
119
116
|
|
|
120
|
-
|
|
117
|
+
workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
|
|
118
|
+
file_full_path = workspace_directory / file_path
|
|
121
119
|
|
|
122
120
|
# Check if file exists
|
|
123
121
|
if not file_full_path.exists():
|
|
@@ -141,13 +139,10 @@ async def _delete_static_file(file_path: str, static_directory: Annotated[Path,
|
|
|
141
139
|
return {"message": f"File {file_path} deleted successfully"}
|
|
142
140
|
|
|
143
141
|
|
|
144
|
-
def _setup_app(
|
|
142
|
+
def _setup_app() -> None:
|
|
145
143
|
"""Setup FastAPI app with middleware and static files."""
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
if not static_dir.exists():
|
|
150
|
-
static_dir.mkdir(parents=True, exist_ok=True)
|
|
144
|
+
workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
|
|
145
|
+
static_files_directory = Path(GriptapeNodes.ConfigManager().get_config_value("static_files_directory"))
|
|
151
146
|
|
|
152
147
|
app.add_middleware(
|
|
153
148
|
CORSMiddleware,
|
|
@@ -163,15 +158,23 @@ def _setup_app(static_directory: Path) -> None:
|
|
|
163
158
|
|
|
164
159
|
app.mount(
|
|
165
160
|
STATIC_SERVER_URL,
|
|
166
|
-
StaticFiles(directory=
|
|
161
|
+
StaticFiles(directory=workspace_directory),
|
|
162
|
+
name="workspace",
|
|
163
|
+
)
|
|
164
|
+
static_files_path = workspace_directory / static_files_directory
|
|
165
|
+
static_files_path.mkdir(parents=True, exist_ok=True)
|
|
166
|
+
# For legacy urls
|
|
167
|
+
app.mount(
|
|
168
|
+
"/static",
|
|
169
|
+
StaticFiles(directory=workspace_directory / static_files_directory),
|
|
167
170
|
name="static",
|
|
168
171
|
)
|
|
169
172
|
|
|
170
173
|
|
|
171
|
-
def start_static_server(
|
|
174
|
+
def start_static_server() -> None:
|
|
172
175
|
"""Run uvicorn server synchronously using uvicorn.run."""
|
|
173
176
|
# Setup the FastAPI app
|
|
174
|
-
_setup_app(
|
|
177
|
+
_setup_app()
|
|
175
178
|
|
|
176
179
|
try:
|
|
177
180
|
# Run server using uvicorn.run
|
|
@@ -11,7 +11,7 @@ class AddParameterButton(Trait):
|
|
|
11
11
|
|
|
12
12
|
def __init__(self) -> None:
    """Initialise the trait and attach its single "AddParameter" child button."""
    super().__init__(element_id="AddParameterButton")
    add_parameter_button = Button(label="AddParameter")
    self.add_child(add_parameter_button)
|
|
15
15
|
|
|
16
16
|
@classmethod
|
|
17
17
|
def get_trait_keys(cls) -> list[str]:
|
griptape_nodes/traits/button.py
CHANGED
|
@@ -1,21 +1,349 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Literal, get_args
|
|
2
4
|
|
|
3
|
-
from griptape_nodes.exe_types.core_types import Trait
|
|
5
|
+
from griptape_nodes.exe_types.core_types import NodeMessagePayload, NodeMessageResult, Trait
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from collections.abc import Callable
|
|
9
|
+
|
|
10
|
+
# Don't export callback types - let users import explicitly
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger("griptape_nodes")
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Type aliases using Literals: each alias enumerates the values accepted by the
# correspondingly named Button field (variant, size, state, icon_position).
ButtonVariant = Literal["default", "secondary", "destructive", "outline", "ghost", "link"]

ButtonSize = Literal["default", "sm", "icon"]

ButtonState = Literal["normal", "disabled", "loading", "hidden"]

IconPosition = Literal["left", "right"]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ButtonDetailsMessagePayload(NodeMessagePayload):
    """Payload containing complete button details and status information.

    ``Button.get_button_details`` builds this from the button's current attributes,
    and the button's message handlers return it as the ``response`` of their results.
    """

    label: str  # Mirrors Button.label.
    variant: str  # Mirrors Button.variant (a ButtonVariant value when produced by Button).
    size: str  # Mirrors Button.size (a ButtonSize value when produced by Button).
    state: str  # Button.state, or the explicit override passed to get_button_details.
    icon: str | None = None  # Mirrors Button.icon.
    icon_class: str | None = None  # Mirrors Button.icon_class.
    icon_position: str | None = None  # Mirrors Button.icon_position ("left"/"right" or None).
    full_width: bool = False  # Mirrors Button.full_width.
    loading_label: str | None = None  # Mirrors Button.loading_label.
    loading_icon: str | None = None  # Mirrors Button.loading_icon.
    loading_icon_class: str | None = None  # Mirrors Button.loading_icon_class.
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class OnClickMessageResultPayload(NodeMessagePayload):
    """Payload for button click result messages.

    NOTE(review): not referenced by Button in this module; Button's on_click handling
    responds with a ButtonDetailsMessagePayload directly. Confirm where this wrapper is used.
    """

    button_details: ButtonDetailsMessagePayload  # Snapshot of the clicked button's details.
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class SetButtonStatusMessagePayload(NodeMessagePayload):
    """Payload for setting button status with explicit field updates."""

    # Maps Button attribute names (e.g. "label", "state", "full_width") to their new values;
    # Button's set_button_status handler validates each entry before assigning it.
    updates: dict[str, str | bool | None]
|
|
4
70
|
|
|
5
71
|
|
|
6
72
|
@dataclass(eq=False)
|
|
7
73
|
class Button(Trait):
|
|
8
|
-
type
|
|
74
|
+
# Specific callback types for better type safety and clarity
|
|
75
|
+
type OnClickCallback = Callable[[Button, ButtonDetailsMessagePayload], NodeMessageResult | None]
|
|
76
|
+
type GetButtonStateCallback = Callable[[Button, ButtonDetailsMessagePayload], NodeMessageResult | None]
|
|
77
|
+
|
|
78
|
+
# Static message type constants
|
|
79
|
+
ON_CLICK_MESSAGE_TYPE = "on_click"
|
|
80
|
+
GET_BUTTON_STATUS_MESSAGE_TYPE = "get_button_status"
|
|
81
|
+
SET_BUTTON_STATUS_MESSAGE_TYPE = "set_button_status"
|
|
82
|
+
|
|
83
|
+
# Button styling and behavior properties
|
|
84
|
+
label: str = "Button"
|
|
85
|
+
variant: ButtonVariant = "default"
|
|
86
|
+
size: ButtonSize = "default"
|
|
87
|
+
state: ButtonState = "normal"
|
|
88
|
+
icon: str | None = None
|
|
89
|
+
icon_class: str | None = None
|
|
90
|
+
icon_position: IconPosition | None = None
|
|
91
|
+
full_width: bool = False
|
|
92
|
+
loading_label: str | None = None
|
|
93
|
+
loading_icon: str | None = None
|
|
94
|
+
loading_icon_class: str | None = None
|
|
95
|
+
|
|
9
96
|
element_id: str = field(default_factory=lambda: "Button")
|
|
97
|
+
on_click_callback: OnClickCallback | None = field(default=None, init=False)
|
|
98
|
+
get_button_state_callback: GetButtonStateCallback | None = field(default=None, init=False)
|
|
10
99
|
|
|
11
|
-
def __init__(
|
|
100
|
+
def __init__( # noqa: PLR0913
|
|
101
|
+
self,
|
|
102
|
+
*,
|
|
103
|
+
label: str = "", # Allows a button with no text.
|
|
104
|
+
variant: ButtonVariant = "secondary",
|
|
105
|
+
size: ButtonSize = "default",
|
|
106
|
+
state: ButtonState = "normal",
|
|
107
|
+
icon: str | None = None,
|
|
108
|
+
icon_class: str | None = None,
|
|
109
|
+
icon_position: IconPosition | None = None,
|
|
110
|
+
full_width: bool = False,
|
|
111
|
+
loading_label: str | None = None,
|
|
112
|
+
loading_icon: str | None = None,
|
|
113
|
+
loading_icon_class: str | None = None,
|
|
114
|
+
on_click: OnClickCallback | None = None,
|
|
115
|
+
get_button_state: GetButtonStateCallback | None = None,
|
|
116
|
+
) -> None:
|
|
12
117
|
super().__init__(element_id="Button")
|
|
13
|
-
|
|
14
|
-
|
|
118
|
+
self.label = label
|
|
119
|
+
self.variant = variant
|
|
120
|
+
self.size = size
|
|
121
|
+
self.state = state
|
|
122
|
+
self.icon = icon
|
|
123
|
+
self.icon_class = icon_class
|
|
124
|
+
self.icon_position = icon_position
|
|
125
|
+
self.full_width = full_width
|
|
126
|
+
self.loading_label = loading_label
|
|
127
|
+
self.loading_icon = loading_icon
|
|
128
|
+
self.loading_icon_class = loading_icon_class
|
|
129
|
+
self.on_click_callback = on_click
|
|
130
|
+
self.get_button_state_callback = get_button_state
|
|
15
131
|
|
|
16
132
|
@classmethod
|
|
17
133
|
def get_trait_keys(cls) -> list[str]:
|
|
18
134
|
return ["button", "addbutton"]
|
|
19
135
|
|
|
136
|
+
def get_button_details(self, state: ButtonState | None = None) -> ButtonDetailsMessagePayload:
|
|
137
|
+
"""Create a ButtonDetailsMessagePayload with current or specified button state."""
|
|
138
|
+
return ButtonDetailsMessagePayload(
|
|
139
|
+
label=self.label,
|
|
140
|
+
variant=self.variant,
|
|
141
|
+
size=self.size,
|
|
142
|
+
state=state or self.state,
|
|
143
|
+
icon=self.icon,
|
|
144
|
+
icon_class=self.icon_class,
|
|
145
|
+
icon_position=self.icon_position,
|
|
146
|
+
full_width=self.full_width,
|
|
147
|
+
loading_label=self.loading_label,
|
|
148
|
+
loading_icon=self.loading_icon,
|
|
149
|
+
loading_icon_class=self.loading_icon_class,
|
|
150
|
+
)
|
|
151
|
+
|
|
20
152
|
def ui_options_for_trait(self) -> dict:
|
|
21
|
-
|
|
153
|
+
"""Generate UI options for the button trait with all styling properties."""
|
|
154
|
+
options = {
|
|
155
|
+
"button_label": self.label,
|
|
156
|
+
"variant": self.variant,
|
|
157
|
+
"size": self.size,
|
|
158
|
+
"state": self.state,
|
|
159
|
+
"full_width": self.full_width,
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
# Only include icon properties if icon is specified
|
|
163
|
+
if self.icon:
|
|
164
|
+
options["button_icon"] = self.icon
|
|
165
|
+
options["iconPosition"] = self.icon_position or "left"
|
|
166
|
+
if self.icon_class:
|
|
167
|
+
options["icon_class"] = self.icon_class
|
|
168
|
+
|
|
169
|
+
# Include loading properties if specified
|
|
170
|
+
if self.loading_label:
|
|
171
|
+
options["loading_label"] = self.loading_label
|
|
172
|
+
if self.loading_icon:
|
|
173
|
+
options["loading_icon"] = self.loading_icon
|
|
174
|
+
if self.loading_icon_class:
|
|
175
|
+
options["loading_icon_class"] = self.loading_icon_class
|
|
176
|
+
|
|
177
|
+
return options
|
|
178
|
+
|
|
179
|
+
def on_message_received(self, message_type: str, message: NodeMessagePayload | None) -> NodeMessageResult | None: # noqa: PLR0911
|
|
180
|
+
"""Handle messages sent to this button trait.
|
|
181
|
+
|
|
182
|
+
Args:
|
|
183
|
+
message_type: String indicating the message type for parsing
|
|
184
|
+
message: Message payload as NodeMessagePayload or None
|
|
185
|
+
|
|
186
|
+
Returns:
|
|
187
|
+
NodeMessageResult | None: Result if handled, None if no handler available
|
|
188
|
+
"""
|
|
189
|
+
match message_type.lower():
|
|
190
|
+
case self.ON_CLICK_MESSAGE_TYPE:
|
|
191
|
+
if self.on_click_callback is not None:
|
|
192
|
+
try:
|
|
193
|
+
# Pre-fill button details with current state and pass to callback
|
|
194
|
+
button_details = self.get_button_details()
|
|
195
|
+
result = self.on_click_callback(self, button_details)
|
|
196
|
+
|
|
197
|
+
# If callback returns None, provide optimistic success result
|
|
198
|
+
if result is None:
|
|
199
|
+
result = NodeMessageResult(
|
|
200
|
+
success=True,
|
|
201
|
+
details=f"Button '{self.label}' clicked successfully",
|
|
202
|
+
response=button_details,
|
|
203
|
+
)
|
|
204
|
+
return result # noqa: TRY300
|
|
205
|
+
except Exception as e:
|
|
206
|
+
return NodeMessageResult(
|
|
207
|
+
success=False,
|
|
208
|
+
details=f"Button '{self.label}' callback failed: {e!s}",
|
|
209
|
+
response=None,
|
|
210
|
+
)
|
|
211
|
+
|
|
212
|
+
# Log debug message and fall through if no callback specified
|
|
213
|
+
logger.debug("Button '%s' was clicked, but no on_click_callback was specified.", self.label)
|
|
214
|
+
|
|
215
|
+
case self.GET_BUTTON_STATUS_MESSAGE_TYPE:
|
|
216
|
+
# Use custom callback if provided, otherwise use default implementation
|
|
217
|
+
if self.get_button_state_callback is not None:
|
|
218
|
+
try:
|
|
219
|
+
# Pre-fill button details with current state and pass to callback
|
|
220
|
+
button_details = self.get_button_details()
|
|
221
|
+
result = self.get_button_state_callback(self, button_details)
|
|
222
|
+
|
|
223
|
+
# If callback returns None, provide optimistic success result
|
|
224
|
+
if result is None:
|
|
225
|
+
result = NodeMessageResult(
|
|
226
|
+
success=True,
|
|
227
|
+
details=f"Button '{self.label}' state retrieved successfully",
|
|
228
|
+
response=button_details,
|
|
229
|
+
altered_workflow_state=False,
|
|
230
|
+
)
|
|
231
|
+
return result # noqa: TRY300
|
|
232
|
+
except Exception as e:
|
|
233
|
+
return NodeMessageResult(
|
|
234
|
+
success=False,
|
|
235
|
+
details=f"Button '{self.label}' get_button_state callback failed: {e!s}",
|
|
236
|
+
response=None,
|
|
237
|
+
)
|
|
238
|
+
else:
|
|
239
|
+
return self._default_get_button_status(message_type, message)
|
|
240
|
+
|
|
241
|
+
case self.SET_BUTTON_STATUS_MESSAGE_TYPE:
|
|
242
|
+
return self._handle_set_button_status(message)
|
|
243
|
+
|
|
244
|
+
# Delegate to parent implementation for unhandled messages or no callback
|
|
245
|
+
return super().on_message_received(message_type, message)
|
|
246
|
+
|
|
247
|
+
def _default_get_button_status(
|
|
248
|
+
self,
|
|
249
|
+
message_type: str, # noqa: ARG002
|
|
250
|
+
message: NodeMessagePayload | None, # noqa: ARG002
|
|
251
|
+
) -> NodeMessageResult:
|
|
252
|
+
"""Default implementation for get_button_status that returns current button details."""
|
|
253
|
+
button_details = self.get_button_details()
|
|
254
|
+
|
|
255
|
+
return NodeMessageResult(
|
|
256
|
+
success=True,
|
|
257
|
+
details=f"Button '{self.label}' details retrieved",
|
|
258
|
+
response=button_details,
|
|
259
|
+
altered_workflow_state=False,
|
|
260
|
+
)
|
|
261
|
+
|
|
262
|
+
def _handle_set_button_status(self, message: NodeMessagePayload | None) -> NodeMessageResult: # noqa: C901
|
|
263
|
+
"""Handle set button status messages by updating fields specified in the updates dict."""
|
|
264
|
+
if not message:
|
|
265
|
+
return NodeMessageResult(
|
|
266
|
+
success=False,
|
|
267
|
+
details="No message payload provided for set_button_status",
|
|
268
|
+
response=None,
|
|
269
|
+
altered_workflow_state=False,
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
if not isinstance(message, SetButtonStatusMessagePayload):
|
|
273
|
+
return NodeMessageResult(
|
|
274
|
+
success=False,
|
|
275
|
+
details="Invalid message payload type for set_button_status",
|
|
276
|
+
response=None,
|
|
277
|
+
altered_workflow_state=False,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
# Track which fields were updated
|
|
281
|
+
updated_fields = []
|
|
282
|
+
validation_errors = []
|
|
283
|
+
|
|
284
|
+
# Valid field names and their expected types
|
|
285
|
+
valid_fields = {
|
|
286
|
+
"label": str,
|
|
287
|
+
"variant": str, # Will validate against ButtonVariant literals
|
|
288
|
+
"size": str, # Will validate against ButtonSize literals
|
|
289
|
+
"state": str, # Will validate against ButtonState literals
|
|
290
|
+
"icon": str,
|
|
291
|
+
"icon_class": str,
|
|
292
|
+
"icon_position": str, # Will validate against IconPosition literals
|
|
293
|
+
"full_width": bool,
|
|
294
|
+
"loading_label": str,
|
|
295
|
+
"loading_icon": str,
|
|
296
|
+
"loading_icon_class": str,
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
# Process each update
|
|
300
|
+
for field_name, value in message.updates.items():
|
|
301
|
+
# Check if field is valid
|
|
302
|
+
if field_name not in valid_fields:
|
|
303
|
+
validation_errors.append(f"Invalid field: {field_name}")
|
|
304
|
+
continue
|
|
305
|
+
|
|
306
|
+
# Type check if value is not None
|
|
307
|
+
if value is not None and not isinstance(value, valid_fields[field_name]):
|
|
308
|
+
validation_errors.append(
|
|
309
|
+
f"Invalid type for {field_name}: expected {valid_fields[field_name].__name__}, got {type(value).__name__}"
|
|
310
|
+
)
|
|
311
|
+
continue
|
|
312
|
+
|
|
313
|
+
# Additional validation for Literal types
|
|
314
|
+
if field_name == "variant" and value is not None and value not in get_args(ButtonVariant):
|
|
315
|
+
validation_errors.append(f"Invalid variant: {value}")
|
|
316
|
+
continue
|
|
317
|
+
if field_name == "size" and value is not None and value not in get_args(ButtonSize):
|
|
318
|
+
validation_errors.append(f"Invalid size: {value}")
|
|
319
|
+
continue
|
|
320
|
+
if field_name == "state" and value is not None and value not in get_args(ButtonState):
|
|
321
|
+
validation_errors.append(f"Invalid state: {value}")
|
|
322
|
+
continue
|
|
323
|
+
if field_name == "icon_position" and value is not None and value not in get_args(IconPosition):
|
|
324
|
+
validation_errors.append(f"Invalid icon_position: {value}")
|
|
325
|
+
continue
|
|
326
|
+
|
|
327
|
+
# Update the field
|
|
328
|
+
setattr(self, field_name, value)
|
|
329
|
+
updated_fields.append(field_name)
|
|
330
|
+
|
|
331
|
+
# Return validation errors if any
|
|
332
|
+
if validation_errors:
|
|
333
|
+
return NodeMessageResult(
|
|
334
|
+
success=False,
|
|
335
|
+
details=f"Validation errors: {'; '.join(validation_errors)}",
|
|
336
|
+
response=None,
|
|
337
|
+
altered_workflow_state=False,
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
# Return success with updated button details
|
|
341
|
+
button_details = self.get_button_details()
|
|
342
|
+
fields_str = ", ".join(updated_fields) if updated_fields else "no fields"
|
|
343
|
+
|
|
344
|
+
return NodeMessageResult(
|
|
345
|
+
success=True,
|
|
346
|
+
details=f"Button '{self.label}' updated ({fields_str})",
|
|
347
|
+
response=button_details,
|
|
348
|
+
altered_workflow_state=True,
|
|
349
|
+
)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
from griptape_nodes.exe_types.core_types import Parameter, ParameterMode, Trait
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass(eq=False)
class ColorPicker(Trait):
    """Trait that attaches a color-picker widget and color-string validation to a parameter.

    ``format`` controls two things: the picker format reported to the UI by
    ``ui_options_for_trait``, and the syntax accepted by the validator returned from
    ``validators_for_trait``.
    """

    format: Literal["hex", "hexa", "rgb", "rgba", "hsl", "hsla", "hsv", "hsva"] = "hex"
    element_id: str = field(default_factory=lambda: "ColorPicker")

    _allowed_modes: set = field(default_factory=lambda: {ParameterMode.PROPERTY})

    def __init__(self, format: Literal["hex", "hexa", "rgb", "rgba", "hsl", "hsla", "hsv", "hsva"] = "hex") -> None:  # noqa: A002
        # A hand-written __init__ replaces the dataclass-generated one, so the field
        # defaults declared above (element_id, _allowed_modes) are not applied
        # automatically. Set them explicitly, the same way Button passes element_id.
        super().__init__(element_id="ColorPicker")
        self._allowed_modes = {ParameterMode.PROPERTY}
        self.format = format

    @classmethod
    def get_trait_keys(cls) -> list[str]:
        return ["color_picker"]

    def ui_options_for_trait(self) -> dict:
        return {"color_picker": {"format": self.format}}

    def _validate_hex_format(self, value: str) -> None:
        """Validate hex and hexa color formats.

        Accepted values:

        - ``#``-prefixed values whose length matches the format: ``#RGB``/``#RRGGBB`` for
          hex, ``#RGBA``/``#RRGGBBAA`` for hexa.
        - Bare 6- or 8-digit values without the ``#`` prefix.

        Every digit must be hexadecimal.

        Raises:
            ValueError: If ``value`` does not match the rules above.
        """
        hex_digits = "0123456789ABCDEFabcdef"
        if not value.startswith("#"):
            # Allow hex without # prefix
            if len(value) in [6, 8] and all(c in hex_digits for c in value):
                return  # Valid hex without # prefix
            msg = f"Invalid {self.format} format: {value}. Expected format like #FF0000 or #FF000088"
            raise ValueError(msg)
        if self.format == "hex" and len(value) not in [4, 7]:  # #fff or #ffffff
            msg = f"Invalid hex format: {value}. Expected format like #FF0000 or #FFF"
            raise ValueError(msg)
        if self.format == "hexa" and len(value) not in [5, 9]:  # #ffff or #ffffffff
            msg = f"Invalid hexa format: {value}. Expected format like #FF000088 or #FFFF"
            raise ValueError(msg)
        # Length alone is not enough: "#GGGGGG" has the right length but is not a color.
        if not all(c in hex_digits for c in value[1:]):
            msg = f"Invalid {self.format} format: {value}. Expected only hexadecimal digits after '#'"
            raise ValueError(msg)

    def _validate_function_format(self, value: str, prefixes: tuple[str, ...], example: str) -> None:
        """Validate function-based color formats (rgb, hsl, hsv).

        The value must start with one of ``prefixes`` (e.g. ``"rgb("``) and end with ``")"``.
        The arguments inside the parentheses are not checked.

        Raises:
            ValueError: If the value is not wrapped in one of the expected functions.
        """
        if not value.startswith(prefixes) or not value.endswith(")"):
            msg = f"Invalid {self.format} format: {value}. Expected format like {example}"
            raise ValueError(msg)

    def validators_for_trait(self) -> list[Callable[..., Any]]:
        """Return a single validator that checks a parameter value against ``self.format``.

        The validator accepts None, raises TypeError for non-string values, and raises
        ValueError for strings that do not match the configured format.
        """

        def validate(param: Parameter, value: Any) -> None:  # noqa: ARG001
            if value is None:
                return

            if not isinstance(value, str):
                msg = f"Color value must be a string for format {self.format}"
                raise TypeError(msg)

            # Validate based on format
            if self.format in ["hex", "hexa"]:
                self._validate_hex_format(value)
            elif self.format in ["rgb", "rgba"]:
                self._validate_function_format(value, ("rgb(", "rgba("), "rgb(255, 255, 255)")
            elif self.format in ["hsl", "hsla"]:
                self._validate_function_format(value, ("hsl(", "hsla("), "hsl(0, 0%, 100%)")
            elif self.format in ["hsv", "hsva"]:
                self._validate_function_format(value, ("hsv(", "hsva("), "hsv(0, 0%, 100%)")

        return [validate]
|