griptape-nodes 0.38.1__py3-none-any.whl → 0.40.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. griptape_nodes/__init__.py +13 -9
  2. griptape_nodes/app/__init__.py +10 -1
  3. griptape_nodes/app/app.py +2 -3
  4. griptape_nodes/app/app_sessions.py +458 -0
  5. griptape_nodes/bootstrap/workflow_executors/__init__.py +1 -0
  6. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +213 -0
  7. griptape_nodes/bootstrap/workflow_executors/workflow_executor.py +13 -0
  8. griptape_nodes/bootstrap/workflow_runners/local_workflow_runner.py +1 -1
  9. griptape_nodes/drivers/storage/__init__.py +4 -0
  10. griptape_nodes/drivers/storage/storage_backend.py +10 -0
  11. griptape_nodes/exe_types/core_types.py +5 -1
  12. griptape_nodes/exe_types/node_types.py +20 -24
  13. griptape_nodes/machines/node_resolution.py +5 -1
  14. griptape_nodes/node_library/advanced_node_library.py +51 -0
  15. griptape_nodes/node_library/library_registry.py +28 -2
  16. griptape_nodes/node_library/workflow_registry.py +1 -1
  17. griptape_nodes/retained_mode/events/agent_events.py +15 -2
  18. griptape_nodes/retained_mode/events/app_events.py +113 -2
  19. griptape_nodes/retained_mode/events/base_events.py +28 -1
  20. griptape_nodes/retained_mode/events/library_events.py +111 -1
  21. griptape_nodes/retained_mode/events/workflow_events.py +1 -0
  22. griptape_nodes/retained_mode/griptape_nodes.py +240 -18
  23. griptape_nodes/retained_mode/managers/agent_manager.py +123 -17
  24. griptape_nodes/retained_mode/managers/flow_manager.py +16 -48
  25. griptape_nodes/retained_mode/managers/library_manager.py +642 -121
  26. griptape_nodes/retained_mode/managers/node_manager.py +1 -1
  27. griptape_nodes/retained_mode/managers/static_files_manager.py +4 -3
  28. griptape_nodes/retained_mode/managers/workflow_manager.py +666 -37
  29. griptape_nodes/retained_mode/utils/__init__.py +1 -0
  30. griptape_nodes/retained_mode/utils/engine_identity.py +131 -0
  31. griptape_nodes/retained_mode/utils/name_generator.py +162 -0
  32. griptape_nodes/retained_mode/utils/session_persistence.py +105 -0
  33. {griptape_nodes-0.38.1.dist-info → griptape_nodes-0.40.0.dist-info}/METADATA +1 -1
  34. {griptape_nodes-0.38.1.dist-info → griptape_nodes-0.40.0.dist-info}/RECORD +37 -27
  35. {griptape_nodes-0.38.1.dist-info → griptape_nodes-0.40.0.dist-info}/WHEEL +0 -0
  36. {griptape_nodes-0.38.1.dist-info → griptape_nodes-0.40.0.dist-info}/entry_points.txt +0 -0
  37. {griptape_nodes-0.38.1.dist-info → griptape_nodes-0.40.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,22 @@
1
+ import json
1
2
  import logging
3
+ import threading
4
+ import uuid
2
5
 
3
- from griptape.artifacts import ErrorArtifact, TextArtifact
6
+ from attrs import define, field
7
+ from griptape.artifacts import ErrorArtifact, ImageUrlArtifact, JsonArtifact
8
+ from griptape.drivers.image_generation import BaseImageGenerationDriver
9
+ from griptape.drivers.image_generation.griptape_cloud import GriptapeCloudImageGenerationDriver
4
10
  from griptape.drivers.prompt.griptape_cloud import GriptapeCloudPromptDriver
5
- from griptape.events import EventBus, FinishTaskEvent, TextChunkEvent
11
+ from griptape.events import EventBus, EventListener, FinishTaskEvent, TextChunkEvent
12
+ from griptape.loaders import ImageLoader
6
13
  from griptape.memory.structure import ConversationMemory
14
+ from griptape.rules import Rule, Ruleset
7
15
  from griptape.structures import Agent
16
+ from griptape.tools import BaseImageGenerationTool
17
+ from griptape.utils.decorators import activity
18
+ from json_repair import repair_json
19
+ from schema import Literal, Schema
8
20
 
9
21
  from griptape_nodes.retained_mode.events.agent_events import (
10
22
  AgentStreamEvent,
@@ -19,12 +31,16 @@ from griptape_nodes.retained_mode.events.agent_events import (
19
31
  ResetAgentConversationMemoryResultSuccess,
20
32
  RunAgentRequest,
21
33
  RunAgentResultFailure,
34
+ RunAgentResultStarted,
22
35
  RunAgentResultSuccess,
23
36
  )
24
37
  from griptape_nodes.retained_mode.events.base_events import ExecutionEvent, ExecutionGriptapeNodeEvent, ResultPayload
25
38
  from griptape_nodes.retained_mode.managers.config_manager import ConfigManager
26
39
  from griptape_nodes.retained_mode.managers.event_manager import EventManager
27
40
  from griptape_nodes.retained_mode.managers.secrets_manager import SecretsManager
41
+ from griptape_nodes.retained_mode.managers.static_files_manager import (
42
+ StaticFilesManager,
43
+ )
28
44
 
29
45
  logger = logging.getLogger("griptape_nodes")
30
46
 
@@ -35,10 +51,40 @@ config_manager = ConfigManager()
35
51
  secrets_manager = SecretsManager(config_manager)
36
52
 
37
53
 
54
+ @define
55
+ class NodesPromptImageGenerationTool(BaseImageGenerationTool):
56
+ image_generation_driver: BaseImageGenerationDriver = field(kw_only=True)
57
+ static_files_manager: StaticFilesManager = field(kw_only=True)
58
+
59
+ @activity(
60
+ config={
61
+ "description": "Generates an image from text prompts. Both prompt and negative_prompt are required.",
62
+ "schema": Schema(
63
+ {
64
+ Literal("prompt", description=BaseImageGenerationTool.PROMPT_DESCRIPTION): str,
65
+ Literal("negative_prompt", description=BaseImageGenerationTool.NEGATIVE_PROMPT_DESCRIPTION): str,
66
+ }
67
+ ),
68
+ },
69
+ )
70
+ def generate_image(self, params: dict[str, dict[str, str]]) -> ImageUrlArtifact | ErrorArtifact:
71
+ prompt = params["values"]["prompt"]
72
+ negative_prompt = params["values"]["negative_prompt"]
73
+
74
+ output_artifact = self.image_generation_driver.run_text_to_image(
75
+ prompts=[prompt], negative_prompts=[negative_prompt]
76
+ )
77
+ filename = f"{uuid.uuid4()}.png"
78
+ image_url = self.static_files_manager.save_static_file(output_artifact.to_bytes(), filename)
79
+ return ImageUrlArtifact(image_url)
80
+
81
+
38
82
  class AgentManager:
39
- def __init__(self, event_manager: EventManager | None = None) -> None:
83
+ def __init__(self, static_files_manager: StaticFilesManager, event_manager: EventManager | None = None) -> None:
40
84
  self.conversation_memory = ConversationMemory()
41
85
  self.prompt_driver = None
86
+ self.image_tool = None
87
+ self.static_files_manager = static_files_manager
42
88
 
43
89
  if event_manager is not None:
44
90
  event_manager.assign_manager_to_request_type(RunAgentRequest, self.on_handle_run_agent_request)
@@ -57,31 +103,91 @@ class AgentManager:
57
103
  raise ValueError(msg)
58
104
  return GriptapeCloudPromptDriver(api_key=api_key, stream=True)
59
105
 
106
+ def _initialize_image_tool(self) -> NodesPromptImageGenerationTool:
107
+ api_key = secrets_manager.get_secret(API_KEY_ENV_VAR)
108
+ if not api_key:
109
+ msg = f"Secret '{API_KEY_ENV_VAR}' not found"
110
+ raise ValueError(msg)
111
+ return NodesPromptImageGenerationTool(
112
+ image_generation_driver=GriptapeCloudImageGenerationDriver(api_key=api_key, model="dall-e-3"),
113
+ static_files_manager=self.static_files_manager,
114
+ )
115
+
60
116
  def on_handle_run_agent_request(self, request: RunAgentRequest) -> ResultPayload:
117
+ if self.prompt_driver is None:
118
+ self.prompt_driver = self._initialize_prompt_driver()
119
+ if self.image_tool is None:
120
+ self.image_tool = self._initialize_image_tool()
121
+ threading.Thread(target=self._on_handle_run_agent_request, args=(request, EventBus.event_listeners)).start()
122
+ return RunAgentResultStarted()
123
+
124
+ def _on_handle_run_agent_request(
125
+ self, request: RunAgentRequest, event_listeners: list[EventListener]
126
+ ) -> ResultPayload:
127
+ EventBus.event_listeners = event_listeners
61
128
  try:
62
- if self.prompt_driver is None:
63
- self.prompt_driver = self._initialize_prompt_driver()
64
- agent = Agent(prompt_driver=self.prompt_driver, conversation_memory=self.conversation_memory)
65
- *events, last_event = agent.run_stream(request.input)
129
+ artifacts = [
130
+ ImageLoader().parse(ImageUrlArtifact.from_dict(url_artifact).to_bytes())
131
+ for url_artifact in request.url_artifacts
132
+ if url_artifact["type"] == "ImageUrlArtifact"
133
+ ]
134
+
135
+ output_schema = Schema(
136
+ {
137
+ "generated_image_urls": [str],
138
+ "conversation_output": str,
139
+ }
140
+ )
141
+ agent = Agent(
142
+ prompt_driver=self.prompt_driver,
143
+ conversation_memory=self.conversation_memory,
144
+ tools=[self.image_tool] if self.image_tool else [],
145
+ output_schema=output_schema,
146
+ rulesets=[
147
+ Ruleset(
148
+ name="generated_image_urls",
149
+ rules=[
150
+ Rule("Do not hallucinate generated_image_urls."),
151
+ Rule("Only set generated_image_urls with images generated with your tools."),
152
+ ],
153
+ ),
154
+ ],
155
+ )
156
+ *events, last_event = agent.run_stream([request.input, *artifacts])
157
+ full_result = ""
158
+ last_conversation_output = ""
66
159
  for event in events:
67
160
  if isinstance(event, TextChunkEvent):
68
- EventBus.publish_event(
69
- ExecutionGriptapeNodeEvent(
70
- wrapped_event=ExecutionEvent(payload=AgentStreamEvent(token=event.token))
71
- )
72
- )
161
+ full_result += event.token
162
+ try:
163
+ result_json = json.loads(repair_json(full_result))
164
+ if "conversation_output" in result_json:
165
+ new_conversation_output = result_json["conversation_output"]
166
+ if new_conversation_output != last_conversation_output:
167
+ EventBus.publish_event(
168
+ ExecutionGriptapeNodeEvent(
169
+ wrapped_event=ExecutionEvent(
170
+ payload=AgentStreamEvent(
171
+ token=new_conversation_output[len(last_conversation_output) :]
172
+ )
173
+ )
174
+ )
175
+ )
176
+ last_conversation_output = new_conversation_output
177
+ except json.JSONDecodeError:
178
+ pass # Ignore incomplete JSON
73
179
  if isinstance(last_event, FinishTaskEvent):
74
180
  if isinstance(last_event.task_output, ErrorArtifact):
75
- return RunAgentResultFailure(last_event.task_output.to_json())
76
- if isinstance(last_event.task_output, TextArtifact):
77
- return RunAgentResultSuccess(last_event.task_output.to_json())
181
+ return RunAgentResultFailure(last_event.task_output.to_dict())
182
+ if isinstance(last_event.task_output, JsonArtifact):
183
+ return RunAgentResultSuccess(last_event.task_output.to_dict())
78
184
  err_msg = f"Unexpected final event: {last_event}"
79
185
  logger.error(err_msg)
80
- return RunAgentResultFailure(ErrorArtifact(last_event).to_json())
186
+ return RunAgentResultFailure(ErrorArtifact(last_event).to_dict())
81
187
  except Exception as e:
82
188
  err_msg = f"Error running agent: {e}"
83
189
  logger.error(err_msg)
84
- return RunAgentResultFailure(ErrorArtifact(e).to_json())
190
+ return RunAgentResultFailure(ErrorArtifact(e).to_dict())
85
191
 
86
192
  def on_handle_configure_agent_request(self, request: ConfigureAgentRequest) -> ResultPayload:
87
193
  try:
@@ -608,32 +608,16 @@ class FlowManager:
608
608
  return CreateConnectionResultFailure()
609
609
 
610
610
  # Let the source make any internal handling decisions now that the Connection has been made.
611
- try:
612
- source_node.after_outgoing_connection(
613
- source_parameter=source_param, target_node=target_node, target_parameter=target_param
614
- )
615
- except TypeError:
616
- source_node.after_outgoing_connection(
617
- source_parameter=source_param,
618
- target_node=target_node,
619
- target_parameter=target_param,
620
- modified_parameters_set=set(),
621
- )
611
+ source_node.after_outgoing_connection(
612
+ source_parameter=source_param, target_node=target_node, target_parameter=target_param
613
+ )
622
614
 
623
615
  # And target.
624
- try:
625
- target_node.after_incoming_connection(
626
- source_node=source_node,
627
- source_parameter=source_param,
628
- target_parameter=target_param,
629
- )
630
- except TypeError:
631
- target_node.after_incoming_connection(
632
- source_node=source_node,
633
- source_parameter=source_param,
634
- target_parameter=target_param,
635
- modified_parameters_set=set(),
636
- )
616
+ target_node.after_incoming_connection(
617
+ source_node=source_node,
618
+ source_parameter=source_param,
619
+ target_parameter=target_param,
620
+ )
637
621
 
638
622
  details = f'Connected "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name}"'
639
623
  logger.debug(details)
@@ -798,32 +782,16 @@ class FlowManager:
798
782
  except KeyError as e:
799
783
  logger.warning(e)
800
784
  # Let the source make any internal handling decisions now that the Connection has been REMOVED.
801
- try:
802
- source_node.after_outgoing_connection_removed(
803
- source_parameter=source_param, target_node=target_node, target_parameter=target_param
804
- )
805
- except TypeError:
806
- source_node.after_outgoing_connection_removed(
807
- source_parameter=source_param,
808
- target_node=target_node,
809
- target_parameter=target_param,
810
- modified_parameters_set=set(),
811
- )
785
+ source_node.after_outgoing_connection_removed(
786
+ source_parameter=source_param, target_node=target_node, target_parameter=target_param
787
+ )
812
788
 
813
789
  # And target.
814
- try:
815
- target_node.after_incoming_connection_removed(
816
- source_node=source_node,
817
- source_parameter=source_param,
818
- target_parameter=target_param,
819
- )
820
- except TypeError:
821
- target_node.after_incoming_connection_removed(
822
- source_node=source_node,
823
- source_parameter=source_param,
824
- target_parameter=target_param,
825
- modified_parameters_set=set(),
826
- )
790
+ target_node.after_incoming_connection_removed(
791
+ source_node=source_node,
792
+ source_parameter=source_param,
793
+ target_parameter=target_param,
794
+ )
827
795
 
828
796
  details = f'Connection "{source_node_name}.{request.source_parameter_name}" to "{target_node_name}.{request.target_parameter_name}" deleted.'
829
797
  logger.debug(details)