griptape-nodes 0.52.1__py3-none-any.whl → 0.54.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. griptape_nodes/__init__.py +8 -942
  2. griptape_nodes/__main__.py +6 -0
  3. griptape_nodes/app/app.py +48 -86
  4. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +35 -5
  5. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +15 -1
  6. griptape_nodes/cli/__init__.py +1 -0
  7. griptape_nodes/cli/commands/__init__.py +1 -0
  8. griptape_nodes/cli/commands/config.py +74 -0
  9. griptape_nodes/cli/commands/engine.py +80 -0
  10. griptape_nodes/cli/commands/init.py +550 -0
  11. griptape_nodes/cli/commands/libraries.py +96 -0
  12. griptape_nodes/cli/commands/models.py +504 -0
  13. griptape_nodes/cli/commands/self.py +120 -0
  14. griptape_nodes/cli/main.py +56 -0
  15. griptape_nodes/cli/shared.py +75 -0
  16. griptape_nodes/common/__init__.py +1 -0
  17. griptape_nodes/common/directed_graph.py +71 -0
  18. griptape_nodes/drivers/storage/base_storage_driver.py +40 -20
  19. griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +24 -29
  20. griptape_nodes/drivers/storage/local_storage_driver.py +23 -14
  21. griptape_nodes/exe_types/core_types.py +60 -2
  22. griptape_nodes/exe_types/node_types.py +257 -38
  23. griptape_nodes/exe_types/param_components/__init__.py +1 -0
  24. griptape_nodes/exe_types/param_components/execution_status_component.py +138 -0
  25. griptape_nodes/machines/control_flow.py +195 -94
  26. griptape_nodes/machines/dag_builder.py +207 -0
  27. griptape_nodes/machines/fsm.py +10 -1
  28. griptape_nodes/machines/parallel_resolution.py +558 -0
  29. griptape_nodes/machines/{node_resolution.py → sequential_resolution.py} +30 -57
  30. griptape_nodes/node_library/library_registry.py +34 -1
  31. griptape_nodes/retained_mode/events/app_events.py +5 -1
  32. griptape_nodes/retained_mode/events/base_events.py +9 -9
  33. griptape_nodes/retained_mode/events/config_events.py +30 -0
  34. griptape_nodes/retained_mode/events/execution_events.py +2 -2
  35. griptape_nodes/retained_mode/events/model_events.py +296 -0
  36. griptape_nodes/retained_mode/events/node_events.py +4 -3
  37. griptape_nodes/retained_mode/griptape_nodes.py +34 -12
  38. griptape_nodes/retained_mode/managers/agent_manager.py +23 -5
  39. griptape_nodes/retained_mode/managers/arbitrary_code_exec_manager.py +3 -1
  40. griptape_nodes/retained_mode/managers/config_manager.py +44 -3
  41. griptape_nodes/retained_mode/managers/context_manager.py +6 -5
  42. griptape_nodes/retained_mode/managers/event_manager.py +8 -2
  43. griptape_nodes/retained_mode/managers/flow_manager.py +150 -206
  44. griptape_nodes/retained_mode/managers/library_lifecycle/library_directory.py +1 -1
  45. griptape_nodes/retained_mode/managers/library_manager.py +35 -25
  46. griptape_nodes/retained_mode/managers/model_manager.py +1107 -0
  47. griptape_nodes/retained_mode/managers/node_manager.py +102 -220
  48. griptape_nodes/retained_mode/managers/object_manager.py +11 -5
  49. griptape_nodes/retained_mode/managers/os_manager.py +28 -13
  50. griptape_nodes/retained_mode/managers/secrets_manager.py +8 -4
  51. griptape_nodes/retained_mode/managers/settings.py +116 -7
  52. griptape_nodes/retained_mode/managers/static_files_manager.py +85 -12
  53. griptape_nodes/retained_mode/managers/sync_manager.py +17 -9
  54. griptape_nodes/retained_mode/managers/workflow_manager.py +186 -192
  55. griptape_nodes/retained_mode/retained_mode.py +19 -0
  56. griptape_nodes/servers/__init__.py +1 -0
  57. griptape_nodes/{mcp_server/server.py → servers/mcp.py} +1 -1
  58. griptape_nodes/{app/api.py → servers/static.py} +43 -40
  59. griptape_nodes/traits/add_param_button.py +1 -1
  60. griptape_nodes/traits/button.py +334 -6
  61. griptape_nodes/traits/color_picker.py +66 -0
  62. griptape_nodes/traits/multi_options.py +188 -0
  63. griptape_nodes/traits/numbers_selector.py +77 -0
  64. griptape_nodes/traits/options.py +93 -2
  65. griptape_nodes/traits/traits.json +4 -0
  66. griptape_nodes/utils/async_utils.py +31 -0
  67. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/METADATA +4 -1
  68. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/RECORD +71 -48
  69. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/WHEEL +1 -1
  70. /griptape_nodes/{mcp_server → servers}/ws_request_manager.py +0 -0
  71. {griptape_nodes-0.52.1.dist-info → griptape_nodes-0.54.0.dist-info}/entry_points.txt +0 -0
@@ -10,6 +10,7 @@ from griptape_nodes.retained_mode.events.base_events import (
10
10
  )
11
11
  from griptape_nodes.retained_mode.events.config_events import (
12
12
  GetConfigCategoryRequest,
13
+ GetConfigSchemaRequest,
13
14
  GetConfigValueRequest,
14
15
  SetConfigCategoryRequest,
15
16
  SetConfigValueRequest,
@@ -1396,6 +1397,24 @@ class RetainedMode:
1396
1397
  result = GriptapeNodes().handle_request(request)
1397
1398
  return result
1398
1399
 
1400
+ @classmethod
1401
+ def get_config_schema(cls) -> ResultPayload:
1402
+ """Gets the JSON schema for the configuration model.
1403
+
1404
+ Returns:
1405
+ ResultPayload: Contains the configuration schema with field types, enums, and validation rules.
1406
+
1407
+ Example:
1408
+ # Get the configuration schema
1409
+ schema_result = cmd.get_config_schema()
1410
+ if isinstance(schema_result, GetConfigSchemaResultSuccess):
1411
+ schema = schema_result.schema
1412
+ # Use schema to render appropriate UI components
1413
+ """
1414
+ request = GetConfigSchemaRequest()
1415
+ result = GriptapeNodes().handle_request(request)
1416
+ return result
1417
+
1399
1418
  @classmethod
1400
1419
  def rename(cls, object_name: str, requested_name: str) -> ResultPayload:
1401
1420
  """Renames a node or flow.
@@ -0,0 +1 @@
1
+ """Package for web servers the engine may need to start."""
@@ -16,7 +16,6 @@ from pydantic import TypeAdapter
16
16
  from rich.logging import RichHandler
17
17
  from starlette.types import Receive, Scope, Send
18
18
 
19
- from griptape_nodes.mcp_server.ws_request_manager import AsyncRequestManager, WebSocketConnectionManager
20
19
  from griptape_nodes.retained_mode.events.base_events import RequestPayload
21
20
  from griptape_nodes.retained_mode.events.connection_events import (
22
21
  CreateConnectionRequest,
@@ -38,6 +37,7 @@ from griptape_nodes.retained_mode.events.parameter_events import (
38
37
  )
39
38
  from griptape_nodes.retained_mode.managers.config_manager import ConfigManager
40
39
  from griptape_nodes.retained_mode.managers.secrets_manager import SecretsManager
40
+ from griptape_nodes.servers.ws_request_manager import AsyncRequestManager, WebSocketConnectionManager
41
41
 
42
42
  SUPPORTED_REQUEST_EVENTS: dict[str, type[RequestPayload]] = {
43
43
  # Nodes
@@ -4,15 +4,16 @@ import binascii
4
4
  import logging
5
5
  import os
6
6
  from pathlib import Path
7
- from typing import Annotated
8
7
  from urllib.parse import urljoin
9
8
 
10
9
  import uvicorn
11
- from fastapi import Depends, FastAPI, HTTPException, Request
10
+ from fastapi import FastAPI, HTTPException, Request
12
11
  from fastapi.middleware.cors import CORSMiddleware
13
12
  from fastapi.staticfiles import StaticFiles
14
13
  from rich.logging import RichHandler
15
14
 
15
+ from griptape_nodes.retained_mode.griptape_nodes import GriptapeNodes
16
+
16
17
  # Whether to enable the static server
17
18
  STATIC_SERVER_ENABLED = os.getenv("STATIC_SERVER_ENABLED", "true").lower() == "true"
18
19
  # Host of the static server
@@ -20,7 +21,7 @@ STATIC_SERVER_HOST = os.getenv("STATIC_SERVER_HOST", "localhost")
20
21
  # Port of the static server
21
22
  STATIC_SERVER_PORT = int(os.getenv("STATIC_SERVER_PORT", "8124"))
22
23
  # URL path for the static server
23
- STATIC_SERVER_URL = os.getenv("STATIC_SERVER_URL", "/static")
24
+ STATIC_SERVER_URL = os.getenv("STATIC_SERVER_URL", "/workspace")
24
25
  # Log level for the static server
25
26
  STATIC_SERVER_LOG_LEVEL = os.getenv("STATIC_SERVER_LOG_LEVEL", "ERROR").lower()
26
27
 
@@ -28,18 +29,6 @@ logger = logging.getLogger("griptape_nodes_api")
28
29
  logging.getLogger("uvicorn").addHandler(RichHandler(show_time=True, show_path=False, markup=True, rich_tracebacks=True))
29
30
 
30
31
 
31
- # Global static directory - initialized as None and set when starting the API
32
- static_dir: Path | None = None
33
-
34
-
35
- def get_static_dir() -> Path:
36
- """FastAPI dependency to get the static directory."""
37
- if static_dir is None:
38
- msg = "Static directory is not initialized"
39
- raise HTTPException(status_code=500, detail=msg)
40
- return static_dir
41
-
42
-
43
32
  """Create and configure the FastAPI application."""
44
33
  app = FastAPI()
45
34
 
@@ -52,35 +41,34 @@ async def _create_static_file_upload_url(request: Request) -> dict:
52
41
  """
53
42
  base_url = request.base_url
54
43
  body = await request.json()
55
- file_name = body["file_name"]
56
- url = urljoin(str(base_url), f"/static-uploads/{file_name}")
44
+ file_path = body["file_path"].lstrip("/")
45
+ url = urljoin(str(base_url), f"/static-uploads/{file_path}")
57
46
 
58
47
  return {"url": url}
59
48
 
60
49
 
61
50
  @app.put("/static-uploads/{file_path:path}")
62
- async def _create_static_file(
63
- request: Request, file_path: str, static_directory: Annotated[Path, Depends(get_static_dir)]
64
- ) -> dict:
51
+ async def _create_static_file(request: Request, file_path: str) -> dict:
65
52
  """Upload a static file to the static server."""
66
53
  if not STATIC_SERVER_ENABLED:
67
54
  msg = "Static server is not enabled. Please set STATIC_SERVER_ENABLED to True."
68
55
  raise ValueError(msg)
69
56
 
70
- file_full_path = Path(static_directory / file_path)
57
+ workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
58
+ full_file_path = workspace_directory / file_path
71
59
 
72
60
  # Create parent directories if they don't exist
73
- file_full_path.parent.mkdir(parents=True, exist_ok=True)
61
+ full_file_path.parent.mkdir(parents=True, exist_ok=True)
74
62
 
75
63
  data = await request.body()
76
64
  try:
77
- file_full_path.write_bytes(data)
65
+ full_file_path.write_bytes(data)
78
66
  except binascii.Error as e:
79
67
  msg = f"Invalid base64 encoding for file {file_path}."
80
68
  logger.error(msg)
81
69
  raise HTTPException(status_code=400, detail=msg) from e
82
70
  except (OSError, PermissionError) as e:
83
- msg = f"Failed to write file {file_path} to {static_dir}: {e}"
71
+ msg = f"Failed to write file {full_file_path}: {e}"
84
72
  logger.error(msg)
85
73
  raise HTTPException(status_code=500, detail=msg) from e
86
74
 
@@ -88,19 +76,28 @@ async def _create_static_file(
88
76
  return {"url": static_url}
89
77
 
90
78
 
79
+ @app.get("/static-uploads/{file_path_prefix:path}")
91
80
  @app.get("/static-uploads/")
92
- async def _list_static_files(static_directory: Annotated[Path, Depends(get_static_dir)]) -> dict:
93
- """List all static files in the static server."""
81
+ async def _list_static_files(file_path_prefix: str = "") -> dict:
82
+ """List static files in the static server under the specified path prefix."""
94
83
  if not STATIC_SERVER_ENABLED:
95
84
  msg = "Static server is not enabled. Please set STATIC_SERVER_ENABLED to True."
96
85
  raise HTTPException(status_code=500, detail=msg)
97
86
 
87
+ workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
88
+
89
+ # Handle the prefix path
90
+ if file_path_prefix:
91
+ target_directory = workspace_directory / file_path_prefix
92
+ else:
93
+ target_directory = workspace_directory
94
+
98
95
  try:
99
96
  file_names = []
100
- if static_directory.exists():
101
- for file_path in static_directory.rglob("*"):
97
+ if target_directory.exists() and target_directory.is_dir():
98
+ for file_path in target_directory.rglob("*"):
102
99
  if file_path.is_file():
103
- relative_path = file_path.relative_to(static_directory)
100
+ relative_path = file_path.relative_to(workspace_directory)
104
101
  file_names.append(str(relative_path))
105
102
  except (OSError, PermissionError) as e:
106
103
  msg = f"Failed to list files in static directory: {e}"
@@ -111,13 +108,14 @@ async def _list_static_files(static_directory: Annotated[Path, Depends(get_stati
111
108
 
112
109
 
113
110
  @app.delete("/static-files/{file_path:path}")
114
- async def _delete_static_file(file_path: str, static_directory: Annotated[Path, Depends(get_static_dir)]) -> dict:
111
+ async def _delete_static_file(file_path: str) -> dict:
115
112
  """Delete a static file from the static server."""
116
113
  if not STATIC_SERVER_ENABLED:
117
114
  msg = "Static server is not enabled. Please set STATIC_SERVER_ENABLED to True."
118
115
  raise HTTPException(status_code=500, detail=msg)
119
116
 
120
- file_full_path = Path(static_directory / file_path)
117
+ workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
118
+ file_full_path = workspace_directory / file_path
121
119
 
122
120
  # Check if file exists
123
121
  if not file_full_path.exists():
@@ -141,13 +139,10 @@ async def _delete_static_file(file_path: str, static_directory: Annotated[Path,
141
139
  return {"message": f"File {file_path} deleted successfully"}
142
140
 
143
141
 
144
- def _setup_app(static_directory: Path) -> None:
142
+ def _setup_app() -> None:
145
143
  """Setup FastAPI app with middleware and static files."""
146
- global static_dir # noqa: PLW0603
147
- static_dir = static_directory
148
-
149
- if not static_dir.exists():
150
- static_dir.mkdir(parents=True, exist_ok=True)
144
+ workspace_directory = Path(GriptapeNodes.ConfigManager().get_config_value("workspace_directory"))
145
+ static_files_directory = Path(GriptapeNodes.ConfigManager().get_config_value("static_files_directory"))
151
146
 
152
147
  app.add_middleware(
153
148
  CORSMiddleware,
@@ -163,15 +158,23 @@ def _setup_app(static_directory: Path) -> None:
163
158
 
164
159
  app.mount(
165
160
  STATIC_SERVER_URL,
166
- StaticFiles(directory=static_directory),
161
+ StaticFiles(directory=workspace_directory),
162
+ name="workspace",
163
+ )
164
+ static_files_path = workspace_directory / static_files_directory
165
+ static_files_path.mkdir(parents=True, exist_ok=True)
166
+ # For legacy urls
167
+ app.mount(
168
+ "/static",
169
+ StaticFiles(directory=workspace_directory / static_files_directory),
167
170
  name="static",
168
171
  )
169
172
 
170
173
 
171
- def start_static_server(static_directory: Path) -> None:
174
+ def start_static_server() -> None:
172
175
  """Run uvicorn server synchronously using uvicorn.run."""
173
176
  # Setup the FastAPI app
174
- _setup_app(static_directory)
177
+ _setup_app()
175
178
 
176
179
  try:
177
180
  # Run server using uvicorn.run
@@ -11,7 +11,7 @@ class AddParameterButton(Trait):
11
11
 
12
12
  def __init__(self) -> None:
13
13
  super().__init__(element_id="AddParameterButton")
14
- self.add_child(Button(button_type="AddParameter"))
14
+ self.add_child(Button(label="AddParameter"))
15
15
 
16
16
  @classmethod
17
17
  def get_trait_keys(cls) -> list[str]:
@@ -1,21 +1,349 @@
1
+ import logging
1
2
  from dataclasses import dataclass, field
3
+ from typing import TYPE_CHECKING, Literal, get_args
2
4
 
3
- from griptape_nodes.exe_types.core_types import Trait
5
+ from griptape_nodes.exe_types.core_types import NodeMessagePayload, NodeMessageResult, Trait
6
+
7
+ if TYPE_CHECKING:
8
+ from collections.abc import Callable
9
+
10
+ # Don't export callback types - let users import explicitly
11
+
12
+ logger = logging.getLogger("griptape_nodes")
13
+
14
+
15
+ # Type aliases using Literals
16
+ ButtonVariant = Literal[
17
+ "default",
18
+ "secondary",
19
+ "destructive",
20
+ "outline",
21
+ "ghost",
22
+ "link",
23
+ ]
24
+
25
+ ButtonSize = Literal[
26
+ "default",
27
+ "sm",
28
+ "icon",
29
+ ]
30
+
31
+ ButtonState = Literal[
32
+ "normal",
33
+ "disabled",
34
+ "loading",
35
+ "hidden",
36
+ ]
37
+
38
+ IconPosition = Literal[
39
+ "left",
40
+ "right",
41
+ ]
42
+
43
+
44
+ class ButtonDetailsMessagePayload(NodeMessagePayload):
45
+ """Payload containing complete button details and status information."""
46
+
47
+ label: str
48
+ variant: str
49
+ size: str
50
+ state: str
51
+ icon: str | None = None
52
+ icon_class: str | None = None
53
+ icon_position: str | None = None
54
+ full_width: bool = False
55
+ loading_label: str | None = None
56
+ loading_icon: str | None = None
57
+ loading_icon_class: str | None = None
58
+
59
+
60
+ class OnClickMessageResultPayload(NodeMessagePayload):
61
+ """Payload for button click result messages."""
62
+
63
+ button_details: ButtonDetailsMessagePayload
64
+
65
+
66
+ class SetButtonStatusMessagePayload(NodeMessagePayload):
67
+ """Payload for setting button status with explicit field updates."""
68
+
69
+ updates: dict[str, str | bool | None]
4
70
 
5
71
 
6
72
  @dataclass(eq=False)
7
73
  class Button(Trait):
8
- type: str = field(default_factory=lambda: "Generic")
74
+ # Specific callback types for better type safety and clarity
75
+ type OnClickCallback = Callable[[Button, ButtonDetailsMessagePayload], NodeMessageResult | None]
76
+ type GetButtonStateCallback = Callable[[Button, ButtonDetailsMessagePayload], NodeMessageResult | None]
77
+
78
+ # Static message type constants
79
+ ON_CLICK_MESSAGE_TYPE = "on_click"
80
+ GET_BUTTON_STATUS_MESSAGE_TYPE = "get_button_status"
81
+ SET_BUTTON_STATUS_MESSAGE_TYPE = "set_button_status"
82
+
83
+ # Button styling and behavior properties
84
+ label: str = "Button"
85
+ variant: ButtonVariant = "default"
86
+ size: ButtonSize = "default"
87
+ state: ButtonState = "normal"
88
+ icon: str | None = None
89
+ icon_class: str | None = None
90
+ icon_position: IconPosition | None = None
91
+ full_width: bool = False
92
+ loading_label: str | None = None
93
+ loading_icon: str | None = None
94
+ loading_icon_class: str | None = None
95
+
9
96
  element_id: str = field(default_factory=lambda: "Button")
97
+ on_click_callback: OnClickCallback | None = field(default=None, init=False)
98
+ get_button_state_callback: GetButtonStateCallback | None = field(default=None, init=False)
10
99
 
11
- def __init__(self, button_type: str | None = None) -> None:
100
+ def __init__( # noqa: PLR0913
101
+ self,
102
+ *,
103
+ label: str = "", # Allows a button with no text.
104
+ variant: ButtonVariant = "secondary",
105
+ size: ButtonSize = "default",
106
+ state: ButtonState = "normal",
107
+ icon: str | None = None,
108
+ icon_class: str | None = None,
109
+ icon_position: IconPosition | None = None,
110
+ full_width: bool = False,
111
+ loading_label: str | None = None,
112
+ loading_icon: str | None = None,
113
+ loading_icon_class: str | None = None,
114
+ on_click: OnClickCallback | None = None,
115
+ get_button_state: GetButtonStateCallback | None = None,
116
+ ) -> None:
12
117
  super().__init__(element_id="Button")
13
- if button_type:
14
- self.type = button_type
118
+ self.label = label
119
+ self.variant = variant
120
+ self.size = size
121
+ self.state = state
122
+ self.icon = icon
123
+ self.icon_class = icon_class
124
+ self.icon_position = icon_position
125
+ self.full_width = full_width
126
+ self.loading_label = loading_label
127
+ self.loading_icon = loading_icon
128
+ self.loading_icon_class = loading_icon_class
129
+ self.on_click_callback = on_click
130
+ self.get_button_state_callback = get_button_state
15
131
 
16
132
  @classmethod
17
133
  def get_trait_keys(cls) -> list[str]:
18
134
  return ["button", "addbutton"]
19
135
 
136
+ def get_button_details(self, state: ButtonState | None = None) -> ButtonDetailsMessagePayload:
137
+ """Create a ButtonDetailsMessagePayload with current or specified button state."""
138
+ return ButtonDetailsMessagePayload(
139
+ label=self.label,
140
+ variant=self.variant,
141
+ size=self.size,
142
+ state=state or self.state,
143
+ icon=self.icon,
144
+ icon_class=self.icon_class,
145
+ icon_position=self.icon_position,
146
+ full_width=self.full_width,
147
+ loading_label=self.loading_label,
148
+ loading_icon=self.loading_icon,
149
+ loading_icon_class=self.loading_icon_class,
150
+ )
151
+
20
152
  def ui_options_for_trait(self) -> dict:
21
- return {"button": self.type}
153
+ """Generate UI options for the button trait with all styling properties."""
154
+ options = {
155
+ "button_label": self.label,
156
+ "variant": self.variant,
157
+ "size": self.size,
158
+ "state": self.state,
159
+ "full_width": self.full_width,
160
+ }
161
+
162
+ # Only include icon properties if icon is specified
163
+ if self.icon:
164
+ options["button_icon"] = self.icon
165
+ options["iconPosition"] = self.icon_position or "left"
166
+ if self.icon_class:
167
+ options["icon_class"] = self.icon_class
168
+
169
+ # Include loading properties if specified
170
+ if self.loading_label:
171
+ options["loading_label"] = self.loading_label
172
+ if self.loading_icon:
173
+ options["loading_icon"] = self.loading_icon
174
+ if self.loading_icon_class:
175
+ options["loading_icon_class"] = self.loading_icon_class
176
+
177
+ return options
178
+
179
+ def on_message_received(self, message_type: str, message: NodeMessagePayload | None) -> NodeMessageResult | None: # noqa: PLR0911
180
+ """Handle messages sent to this button trait.
181
+
182
+ Args:
183
+ message_type: String indicating the message type for parsing
184
+ message: Message payload as NodeMessagePayload or None
185
+
186
+ Returns:
187
+ NodeMessageResult | None: Result if handled, None if no handler available
188
+ """
189
+ match message_type.lower():
190
+ case self.ON_CLICK_MESSAGE_TYPE:
191
+ if self.on_click_callback is not None:
192
+ try:
193
+ # Pre-fill button details with current state and pass to callback
194
+ button_details = self.get_button_details()
195
+ result = self.on_click_callback(self, button_details)
196
+
197
+ # If callback returns None, provide optimistic success result
198
+ if result is None:
199
+ result = NodeMessageResult(
200
+ success=True,
201
+ details=f"Button '{self.label}' clicked successfully",
202
+ response=button_details,
203
+ )
204
+ return result # noqa: TRY300
205
+ except Exception as e:
206
+ return NodeMessageResult(
207
+ success=False,
208
+ details=f"Button '{self.label}' callback failed: {e!s}",
209
+ response=None,
210
+ )
211
+
212
+ # Log debug message and fall through if no callback specified
213
+ logger.debug("Button '%s' was clicked, but no on_click_callback was specified.", self.label)
214
+
215
+ case self.GET_BUTTON_STATUS_MESSAGE_TYPE:
216
+ # Use custom callback if provided, otherwise use default implementation
217
+ if self.get_button_state_callback is not None:
218
+ try:
219
+ # Pre-fill button details with current state and pass to callback
220
+ button_details = self.get_button_details()
221
+ result = self.get_button_state_callback(self, button_details)
222
+
223
+ # If callback returns None, provide optimistic success result
224
+ if result is None:
225
+ result = NodeMessageResult(
226
+ success=True,
227
+ details=f"Button '{self.label}' state retrieved successfully",
228
+ response=button_details,
229
+ altered_workflow_state=False,
230
+ )
231
+ return result # noqa: TRY300
232
+ except Exception as e:
233
+ return NodeMessageResult(
234
+ success=False,
235
+ details=f"Button '{self.label}' get_button_state callback failed: {e!s}",
236
+ response=None,
237
+ )
238
+ else:
239
+ return self._default_get_button_status(message_type, message)
240
+
241
+ case self.SET_BUTTON_STATUS_MESSAGE_TYPE:
242
+ return self._handle_set_button_status(message)
243
+
244
+ # Delegate to parent implementation for unhandled messages or no callback
245
+ return super().on_message_received(message_type, message)
246
+
247
+ def _default_get_button_status(
248
+ self,
249
+ message_type: str, # noqa: ARG002
250
+ message: NodeMessagePayload | None, # noqa: ARG002
251
+ ) -> NodeMessageResult:
252
+ """Default implementation for get_button_status that returns current button details."""
253
+ button_details = self.get_button_details()
254
+
255
+ return NodeMessageResult(
256
+ success=True,
257
+ details=f"Button '{self.label}' details retrieved",
258
+ response=button_details,
259
+ altered_workflow_state=False,
260
+ )
261
+
262
+ def _handle_set_button_status(self, message: NodeMessagePayload | None) -> NodeMessageResult: # noqa: C901
263
+ """Handle set button status messages by updating fields specified in the updates dict."""
264
+ if not message:
265
+ return NodeMessageResult(
266
+ success=False,
267
+ details="No message payload provided for set_button_status",
268
+ response=None,
269
+ altered_workflow_state=False,
270
+ )
271
+
272
+ if not isinstance(message, SetButtonStatusMessagePayload):
273
+ return NodeMessageResult(
274
+ success=False,
275
+ details="Invalid message payload type for set_button_status",
276
+ response=None,
277
+ altered_workflow_state=False,
278
+ )
279
+
280
+ # Track which fields were updated
281
+ updated_fields = []
282
+ validation_errors = []
283
+
284
+ # Valid field names and their expected types
285
+ valid_fields = {
286
+ "label": str,
287
+ "variant": str, # Will validate against ButtonVariant literals
288
+ "size": str, # Will validate against ButtonSize literals
289
+ "state": str, # Will validate against ButtonState literals
290
+ "icon": str,
291
+ "icon_class": str,
292
+ "icon_position": str, # Will validate against IconPosition literals
293
+ "full_width": bool,
294
+ "loading_label": str,
295
+ "loading_icon": str,
296
+ "loading_icon_class": str,
297
+ }
298
+
299
+ # Process each update
300
+ for field_name, value in message.updates.items():
301
+ # Check if field is valid
302
+ if field_name not in valid_fields:
303
+ validation_errors.append(f"Invalid field: {field_name}")
304
+ continue
305
+
306
+ # Type check if value is not None
307
+ if value is not None and not isinstance(value, valid_fields[field_name]):
308
+ validation_errors.append(
309
+ f"Invalid type for {field_name}: expected {valid_fields[field_name].__name__}, got {type(value).__name__}"
310
+ )
311
+ continue
312
+
313
+ # Additional validation for Literal types
314
+ if field_name == "variant" and value is not None and value not in get_args(ButtonVariant):
315
+ validation_errors.append(f"Invalid variant: {value}")
316
+ continue
317
+ if field_name == "size" and value is not None and value not in get_args(ButtonSize):
318
+ validation_errors.append(f"Invalid size: {value}")
319
+ continue
320
+ if field_name == "state" and value is not None and value not in get_args(ButtonState):
321
+ validation_errors.append(f"Invalid state: {value}")
322
+ continue
323
+ if field_name == "icon_position" and value is not None and value not in get_args(IconPosition):
324
+ validation_errors.append(f"Invalid icon_position: {value}")
325
+ continue
326
+
327
+ # Update the field
328
+ setattr(self, field_name, value)
329
+ updated_fields.append(field_name)
330
+
331
+ # Return validation errors if any
332
+ if validation_errors:
333
+ return NodeMessageResult(
334
+ success=False,
335
+ details=f"Validation errors: {'; '.join(validation_errors)}",
336
+ response=None,
337
+ altered_workflow_state=False,
338
+ )
339
+
340
+ # Return success with updated button details
341
+ button_details = self.get_button_details()
342
+ fields_str = ", ".join(updated_fields) if updated_fields else "no fields"
343
+
344
+ return NodeMessageResult(
345
+ success=True,
346
+ details=f"Button '{self.label}' updated ({fields_str})",
347
+ response=button_details,
348
+ altered_workflow_state=True,
349
+ )
@@ -0,0 +1,66 @@
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass, field
3
+ from typing import Any, Literal
4
+
5
+ from griptape_nodes.exe_types.core_types import Parameter, ParameterMode, Trait
6
+
7
+
8
+ @dataclass(eq=False)
9
+ class ColorPicker(Trait):
10
+ format: Literal["hex", "hexa", "rgb", "rgba", "hsl", "hsla", "hsv", "hsva"] = "hex"
11
+ element_id: str = field(default_factory=lambda: "ColorPicker")
12
+
13
+ _allowed_modes: set = field(default_factory=lambda: {ParameterMode.PROPERTY})
14
+
15
+ def __init__(self, format: Literal["hex", "hexa", "rgb", "rgba", "hsl", "hsla", "hsv", "hsva"] = "hex") -> None: # noqa: A002
16
+ super().__init__()
17
+ self.format = format
18
+
19
+ @classmethod
20
+ def get_trait_keys(cls) -> list[str]:
21
+ return ["color_picker"]
22
+
23
+ def ui_options_for_trait(self) -> dict:
24
+ return {"color_picker": {"format": self.format}}
25
+
26
+ def _validate_hex_format(self, value: str) -> None:
27
+ """Validate hex and hexa color formats."""
28
+ if not value.startswith("#"):
29
+ # Allow hex without # prefix
30
+ if len(value) in [6, 8] and all(c in "0123456789ABCDEFabcdef" for c in value):
31
+ return # Valid hex without # prefix
32
+ msg = f"Invalid {self.format} format: {value}. Expected format like #FF0000 or #FF000088"
33
+ raise ValueError(msg)
34
+ if self.format == "hex" and len(value) not in [4, 7]: # #fff or #ffffff
35
+ msg = f"Invalid hex format: {value}. Expected format like #FF0000 or #FFF"
36
+ raise ValueError(msg)
37
+ if self.format == "hexa" and len(value) not in [5, 9]: # #ffff or #ffffffff
38
+ msg = f"Invalid hexa format: {value}. Expected format like #FF000088 or #FFFF"
39
+ raise ValueError(msg)
40
+
41
+ def _validate_function_format(self, value: str, prefixes: tuple[str, ...], example: str) -> None:
42
+ """Validate function-based color formats (rgb, hsl, hsv)."""
43
+ if not value.startswith(prefixes):
44
+ msg = f"Invalid {self.format} format: {value}. Expected format like {example}"
45
+ raise ValueError(msg)
46
+
47
+ def validators_for_trait(self) -> list[Callable[..., Any]]:
48
+ def validate(param: Parameter, value: Any) -> None: # noqa: ARG001
49
+ if value is None:
50
+ return
51
+
52
+ if not isinstance(value, str):
53
+ msg = f"Color value must be a string for format {self.format}"
54
+ raise TypeError(msg)
55
+
56
+ # Validate based on format
57
+ if self.format in ["hex", "hexa"]:
58
+ self._validate_hex_format(value)
59
+ elif self.format in ["rgb", "rgba"]:
60
+ self._validate_function_format(value, ("rgb(", "rgba("), "rgb(255, 255, 255)")
61
+ elif self.format in ["hsl", "hsla"]:
62
+ self._validate_function_format(value, ("hsl(", "hsla("), "hsl(0, 0%, 100%)")
63
+ elif self.format in ["hsv", "hsva"]:
64
+ self._validate_function_format(value, ("hsv(", "hsva("), "hsv(0, 0%, 100%)")
65
+
66
+ return [validate]