griptape-nodes 0.70.1__py3-none-any.whl → 0.72.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. griptape_nodes/api_client/client.py +8 -5
  2. griptape_nodes/app/app.py +4 -0
  3. griptape_nodes/bootstrap/utils/python_subprocess_executor.py +48 -9
  4. griptape_nodes/bootstrap/utils/subprocess_websocket_base.py +88 -0
  5. griptape_nodes/bootstrap/utils/subprocess_websocket_listener.py +126 -0
  6. griptape_nodes/bootstrap/utils/subprocess_websocket_sender.py +121 -0
  7. griptape_nodes/bootstrap/workflow_executors/local_session_workflow_executor.py +17 -170
  8. griptape_nodes/bootstrap/workflow_executors/local_workflow_executor.py +10 -1
  9. griptape_nodes/bootstrap/workflow_executors/subprocess_workflow_executor.py +13 -117
  10. griptape_nodes/bootstrap/workflow_executors/utils/subprocess_script.py +4 -0
  11. griptape_nodes/bootstrap/workflow_publishers/local_session_workflow_publisher.py +206 -0
  12. griptape_nodes/bootstrap/workflow_publishers/subprocess_workflow_publisher.py +22 -3
  13. griptape_nodes/bootstrap/workflow_publishers/utils/subprocess_script.py +49 -25
  14. griptape_nodes/common/node_executor.py +61 -14
  15. griptape_nodes/drivers/image_metadata/__init__.py +21 -0
  16. griptape_nodes/drivers/image_metadata/base_image_metadata_driver.py +63 -0
  17. griptape_nodes/drivers/image_metadata/exif_metadata_driver.py +218 -0
  18. griptape_nodes/drivers/image_metadata/image_metadata_driver_registry.py +55 -0
  19. griptape_nodes/drivers/image_metadata/png_metadata_driver.py +71 -0
  20. griptape_nodes/drivers/storage/base_storage_driver.py +32 -0
  21. griptape_nodes/drivers/storage/griptape_cloud_storage_driver.py +384 -10
  22. griptape_nodes/drivers/storage/local_storage_driver.py +65 -4
  23. griptape_nodes/drivers/thread_storage/local_thread_storage_driver.py +1 -0
  24. griptape_nodes/exe_types/base_iterative_nodes.py +1 -1
  25. griptape_nodes/exe_types/node_groups/base_node_group.py +3 -0
  26. griptape_nodes/exe_types/node_groups/subflow_node_group.py +18 -0
  27. griptape_nodes/exe_types/node_types.py +13 -0
  28. griptape_nodes/exe_types/param_components/log_parameter.py +4 -4
  29. griptape_nodes/exe_types/param_components/subflow_execution_component.py +329 -0
  30. griptape_nodes/exe_types/param_types/parameter_audio.py +17 -2
  31. griptape_nodes/exe_types/param_types/parameter_float.py +4 -4
  32. griptape_nodes/exe_types/param_types/parameter_image.py +14 -1
  33. griptape_nodes/exe_types/param_types/parameter_int.py +4 -4
  34. griptape_nodes/exe_types/param_types/parameter_number.py +12 -14
  35. griptape_nodes/exe_types/param_types/parameter_three_d.py +14 -1
  36. griptape_nodes/exe_types/param_types/parameter_video.py +17 -2
  37. griptape_nodes/node_library/workflow_registry.py +5 -8
  38. griptape_nodes/retained_mode/events/app_events.py +1 -0
  39. griptape_nodes/retained_mode/events/base_events.py +42 -26
  40. griptape_nodes/retained_mode/events/flow_events.py +67 -0
  41. griptape_nodes/retained_mode/events/library_events.py +1 -1
  42. griptape_nodes/retained_mode/events/node_events.py +1 -0
  43. griptape_nodes/retained_mode/events/os_events.py +22 -0
  44. griptape_nodes/retained_mode/events/static_file_events.py +28 -4
  45. griptape_nodes/retained_mode/managers/flow_manager.py +134 -0
  46. griptape_nodes/retained_mode/managers/image_metadata_injector.py +339 -0
  47. griptape_nodes/retained_mode/managers/library_manager.py +71 -41
  48. griptape_nodes/retained_mode/managers/model_manager.py +1 -0
  49. griptape_nodes/retained_mode/managers/node_manager.py +8 -5
  50. griptape_nodes/retained_mode/managers/os_manager.py +270 -33
  51. griptape_nodes/retained_mode/managers/project_manager.py +3 -7
  52. griptape_nodes/retained_mode/managers/session_manager.py +1 -0
  53. griptape_nodes/retained_mode/managers/settings.py +5 -0
  54. griptape_nodes/retained_mode/managers/static_files_manager.py +83 -17
  55. griptape_nodes/retained_mode/managers/workflow_manager.py +71 -41
  56. griptape_nodes/servers/static.py +31 -0
  57. griptape_nodes/utils/__init__.py +9 -1
  58. griptape_nodes/utils/artifact_normalization.py +245 -0
  59. griptape_nodes/utils/file_utils.py +13 -13
  60. griptape_nodes/utils/http_file_patch.py +613 -0
  61. griptape_nodes/utils/image_preview.py +27 -0
  62. griptape_nodes/utils/path_utils.py +58 -0
  63. griptape_nodes/utils/url_utils.py +106 -0
  64. {griptape_nodes-0.70.1.dist-info → griptape_nodes-0.72.0.dist-info}/METADATA +2 -1
  65. {griptape_nodes-0.70.1.dist-info → griptape_nodes-0.72.0.dist-info}/RECORD +67 -52
  66. {griptape_nodes-0.70.1.dist-info → griptape_nodes-0.72.0.dist-info}/WHEEL +1 -1
  67. {griptape_nodes-0.70.1.dist-info → griptape_nodes-0.72.0.dist-info}/entry_points.txt +0 -0
@@ -8,9 +8,11 @@ from typing import TYPE_CHECKING, Any, Self
8
8
  import anyio
9
9
 
10
10
  from griptape_nodes.bootstrap.utils.python_subprocess_executor import PythonSubprocessExecutor
11
+ from griptape_nodes.bootstrap.utils.subprocess_websocket_listener import SubprocessWebSocketListenerMixin
11
12
  from griptape_nodes.bootstrap.workflow_publishers.local_workflow_publisher import LocalWorkflowPublisher
12
13
 
13
14
  if TYPE_CHECKING:
15
+ from collections.abc import Callable
14
16
  from types import TracebackType
15
17
 
16
18
  logger = logging.getLogger(__name__)
@@ -20,11 +22,18 @@ class SubprocessWorkflowPublisherError(Exception):
20
22
  """Exception raised during subprocess workflow publishing."""
21
23
 
22
24
 
23
- class SubprocessWorkflowPublisher(LocalWorkflowPublisher, PythonSubprocessExecutor):
24
- def __init__(self) -> None:
25
+ class SubprocessWorkflowPublisher(LocalWorkflowPublisher, PythonSubprocessExecutor, SubprocessWebSocketListenerMixin):
26
+ def __init__(
27
+ self,
28
+ on_event: Callable[[dict], None] | None = None,
29
+ session_id: str | None = None,
30
+ ) -> None:
25
31
  PythonSubprocessExecutor.__init__(self)
32
+ self._init_websocket_listener(session_id=session_id, on_event=on_event)
26
33
 
27
34
  async def __aenter__(self) -> Self:
35
+ """Async context manager entry: start WebSocket listener."""
36
+ await self._start_websocket_listener()
28
37
  return self
29
38
 
30
39
  async def __aexit__(
@@ -33,7 +42,8 @@ class SubprocessWorkflowPublisher(LocalWorkflowPublisher, PythonSubprocessExecut
33
42
  exc_val: BaseException | None,
34
43
  exc_tb: TracebackType | None,
35
44
  ) -> None:
36
- return
45
+ """Async context manager exit: stop WebSocket listener."""
46
+ await self._stop_websocket_listener()
37
47
 
38
48
  async def arun(
39
49
  self,
@@ -76,6 +86,8 @@ class SubprocessWorkflowPublisher(LocalWorkflowPublisher, PythonSubprocessExecut
76
86
  publisher_name,
77
87
  "--published-workflow-file-name",
78
88
  published_workflow_file_name,
89
+ "--session-id",
90
+ self._session_id,
79
91
  ]
80
92
  if kwargs.get("pickle_control_flow_result"):
81
93
  args.append("--pickle-control-flow-result")
@@ -87,3 +99,10 @@ class SubprocessWorkflowPublisher(LocalWorkflowPublisher, PythonSubprocessExecut
87
99
  "GTN_CONFIG_ENABLE_WORKSPACE_FILE_WATCHING": "false",
88
100
  },
89
101
  )
102
+
103
+ async def _handle_subprocess_event(self, event: dict) -> None:
104
+ """Handle publisher-specific events from the subprocess.
105
+
106
+ Currently, this is a no-op as we just forward all events via the on_event callback.
107
+ Subclasses can override to add specific event handling logic.
108
+ """
@@ -1,33 +1,51 @@
1
1
  import asyncio
2
2
  import logging
3
3
  from argparse import ArgumentParser
4
+ from dataclasses import dataclass
4
5
 
6
+ from griptape_nodes.bootstrap.workflow_publishers.local_session_workflow_publisher import (
7
+ LocalSessionWorkflowPublisher,
8
+ )
5
9
  from griptape_nodes.bootstrap.workflow_publishers.local_workflow_publisher import LocalWorkflowPublisher
10
+ from griptape_nodes.utils import install_file_url_support
11
+
12
+ # Install file:// URL support for httpx/requests in subprocess
13
+ install_file_url_support()
6
14
 
7
15
  logging.basicConfig(level=logging.INFO)
8
16
 
9
17
  logger = logging.getLogger(__name__)
10
18
 
11
19
 
12
- async def _main(
13
- workflow_name: str,
14
- workflow_path: str,
15
- publisher_name: str,
16
- published_workflow_file_name: str,
17
- *,
18
- pickle_control_flow_result: bool,
19
- ) -> None:
20
- local_publisher = LocalWorkflowPublisher()
21
- async with local_publisher as publisher:
20
+ @dataclass
21
+ class PublishWorkflowArgs:
22
+ """Arguments for publishing a workflow."""
23
+
24
+ workflow_name: str
25
+ workflow_path: str
26
+ publisher_name: str
27
+ published_workflow_file_name: str
28
+ pickle_control_flow_result: bool
29
+ session_id: str | None = None
30
+
31
+
32
+ async def _main(args: PublishWorkflowArgs) -> None:
33
+ publisher: LocalWorkflowPublisher
34
+ if args.session_id is not None:
35
+ publisher = LocalSessionWorkflowPublisher(session_id=args.session_id)
36
+ else:
37
+ publisher = LocalWorkflowPublisher()
38
+
39
+ async with publisher:
22
40
  await publisher.arun(
23
- workflow_name=workflow_name,
24
- workflow_path=workflow_path,
25
- publisher_name=publisher_name,
26
- published_workflow_file_name=published_workflow_file_name,
27
- pickle_control_flow_result=pickle_control_flow_result,
41
+ workflow_name=args.workflow_name,
42
+ workflow_path=args.workflow_path,
43
+ publisher_name=args.publisher_name,
44
+ published_workflow_file_name=args.published_workflow_file_name,
45
+ pickle_control_flow_result=args.pickle_control_flow_result,
28
46
  )
29
47
 
30
- msg = f"Published workflow to file: {published_workflow_file_name}"
48
+ msg = f"Published workflow to file: {args.published_workflow_file_name}"
31
49
  logger.info(msg)
32
50
 
33
51
 
@@ -57,13 +75,19 @@ if __name__ == "__main__":
57
75
  default=False,
58
76
  help="Whether to pickle control flow results",
59
77
  )
60
- args = parser.parse_args()
61
- asyncio.run(
62
- _main(
63
- workflow_name=args.workflow_name,
64
- workflow_path=args.workflow_path,
65
- publisher_name=args.publisher_name,
66
- published_workflow_file_name=args.published_workflow_file_name,
67
- pickle_control_flow_result=args.pickle_control_flow_result,
68
- )
78
+ parser.add_argument(
79
+ "--session-id",
80
+ default=None,
81
+ help="Session ID for WebSocket event emission",
82
+ )
83
+ parsed_args = parser.parse_args()
84
+
85
+ publish_args = PublishWorkflowArgs(
86
+ workflow_name=parsed_args.workflow_name,
87
+ workflow_path=parsed_args.workflow_path,
88
+ publisher_name=parsed_args.publisher_name,
89
+ published_workflow_file_name=parsed_args.published_workflow_file_name,
90
+ pickle_control_flow_result=parsed_args.pickle_control_flow_result,
91
+ session_id=parsed_args.session_id,
69
92
  )
93
+ asyncio.run(_main(publish_args))
@@ -102,6 +102,8 @@ from griptape_nodes.retained_mode.managers.event_manager import (
102
102
  )
103
103
 
104
104
  if TYPE_CHECKING:
105
+ from collections.abc import Callable
106
+
105
107
  from griptape_nodes.retained_mode.events.node_events import SerializedNodeCommands
106
108
  from griptape_nodes.retained_mode.managers.library_manager import LibraryManager
107
109
 
@@ -220,6 +222,8 @@ class NodeExecutor:
220
222
  # Just execute the node normally! This means we aren't doing any special packaging.
221
223
  await node.aprocess()
222
224
  return
225
+ # Clear execution state before subprocess execution starts
226
+ node.subflow_execution_component.clear_execution_state()
223
227
  if execution_type == PRIVATE_EXECUTION:
224
228
  # Package the flow and run it in a subprocess.
225
229
  await self._execute_private_workflow(node)
@@ -251,7 +255,9 @@ class NodeExecutor:
251
255
  file_name: Name of workflow for logging
252
256
  package_result: The packaging result containing parameter mappings
253
257
  """
254
- my_subprocess_result = await self._execute_subprocess(workflow_path, file_name)
258
+ # Pass node for event updates if it's a SubflowNodeGroup
259
+ subflow_node = node if isinstance(node, SubflowNodeGroup) else None
260
+ my_subprocess_result = await self._execute_subprocess(workflow_path, file_name, node=subflow_node)
255
261
  parameter_output_values = self._extract_parameter_output_values(my_subprocess_result)
256
262
  self._apply_parameter_values_to_node(node, parameter_output_values, package_result)
257
263
 
@@ -343,7 +349,7 @@ class NodeExecutor:
343
349
 
344
350
  try:
345
351
  published_workflow_filename = await self._publish_library_workflow(
346
- workflow_result, library_name, result.file_name
352
+ workflow_result, library_name, result.file_name, node=node
347
353
  )
348
354
  except Exception as e:
349
355
  logger.exception(
@@ -481,19 +487,29 @@ class NodeExecutor:
481
487
  )
482
488
 
483
489
  async def _publish_library_workflow(
484
- self, workflow_result: SaveWorkflowFileFromSerializedFlowResultSuccess, library_name: str, file_name: str
490
+ self,
491
+ workflow_result: SaveWorkflowFileFromSerializedFlowResultSuccess,
492
+ library_name: str,
493
+ file_name: str,
494
+ node: BaseNode | None = None,
485
495
  ) -> Path:
486
- subprocess_workflow_publisher = SubprocessWorkflowPublisher()
496
+ # Define event callback if node is a SubflowNodeGroup for GUI updates
497
+ on_event: Callable[[dict], None] | None = None
498
+ if isinstance(node, SubflowNodeGroup):
499
+ on_event = node.subflow_execution_component.handle_publishing_event
500
+
501
+ subprocess_workflow_publisher = SubprocessWorkflowPublisher(on_event=on_event)
487
502
  published_filename = f"{Path(workflow_result.file_path).stem}_published"
488
503
  published_workflow_filename = GriptapeNodes.ConfigManager().workspace_path / (published_filename + ".py")
489
504
 
490
- await subprocess_workflow_publisher.arun(
491
- workflow_name=file_name,
492
- workflow_path=workflow_result.file_path,
493
- publisher_name=library_name,
494
- published_workflow_file_name=published_filename,
495
- pickle_control_flow_result=True,
496
- )
505
+ async with subprocess_workflow_publisher:
506
+ await subprocess_workflow_publisher.arun(
507
+ workflow_name=file_name,
508
+ workflow_path=workflow_result.file_path,
509
+ publisher_name=library_name,
510
+ published_workflow_file_name=published_filename,
511
+ pickle_control_flow_result=True,
512
+ )
497
513
 
498
514
  if not published_workflow_filename.exists():
499
515
  msg = f"Published workflow file does not exist at path: {published_workflow_filename}"
@@ -507,6 +523,7 @@ class NodeExecutor:
507
523
  file_name: str,
508
524
  pickle_control_flow_result: bool = True, # noqa: FBT001, FBT002
509
525
  flow_input: dict[str, Any] | None = None,
526
+ node: SubflowNodeGroup | None = None,
510
527
  ) -> dict[str, dict[str | SerializedNodeCommands.UniqueParameterValueUUID, Any] | None]:
511
528
  """Execute the published workflow in a subprocess.
512
529
 
@@ -515,6 +532,7 @@ class NodeExecutor:
515
532
  file_name: Name of the workflow for logging
516
533
  pickle_control_flow_result: Whether to pickle control flow results (defaults to True)
517
534
  flow_input: Optional dictionary of parameter values to pass to the workflow's StartFlow node
535
+ node: Optional SubflowNodeGroup to receive real-time event updates
518
536
 
519
537
  Returns:
520
538
  The subprocess execution output dictionary
@@ -523,7 +541,15 @@ class NodeExecutor:
523
541
  SubprocessWorkflowExecutor,
524
542
  )
525
543
 
526
- subprocess_executor = SubprocessWorkflowExecutor(workflow_path=str(published_workflow_filename))
544
+ # Define event callback if node provided for GUI updates
545
+ on_event: Callable[[dict], None] | None = None
546
+ if node is not None:
547
+ on_event = node.subflow_execution_component.handle_execution_event
548
+
549
+ subprocess_executor = SubprocessWorkflowExecutor(
550
+ workflow_path=str(published_workflow_filename),
551
+ on_event=on_event,
552
+ )
527
553
  try:
528
554
  async with subprocess_executor as executor:
529
555
  await executor.arun(
@@ -1454,6 +1480,10 @@ class NodeExecutor:
1454
1480
  # Get execution environment
1455
1481
  execution_type = node.get_parameter_value(node.execution_environment.name)
1456
1482
 
1483
+ # Clear execution state before subprocess execution starts (for non-local execution)
1484
+ if execution_type != LOCAL_EXECUTION:
1485
+ node.subflow_execution_component.clear_execution_state()
1486
+
1457
1487
  # Check if we should run in order (default is sequential/True)
1458
1488
  run_in_order = node.get_parameter_value("run_in_order")
1459
1489
 
@@ -2550,11 +2580,14 @@ class NodeExecutor:
2550
2580
  end_loop_node.name,
2551
2581
  )
2552
2582
 
2583
+ # Pass node for event updates if it's a SubflowNodeGroup (includes BaseIterativeNodeGroup)
2584
+ subflow_node = end_loop_node if isinstance(end_loop_node, SubflowNodeGroup) else None
2553
2585
  subprocess_result = await self._execute_subprocess(
2554
2586
  published_workflow_filename=workflow_path,
2555
2587
  file_name=f"{file_name_prefix}_iteration_{iteration_index}",
2556
2588
  pickle_control_flow_result=True,
2557
2589
  flow_input=flow_input,
2590
+ node=subflow_node,
2558
2591
  )
2559
2592
  iteration_outputs.append((iteration_index, True, subprocess_result))
2560
2593
  except Exception:
@@ -2562,6 +2595,9 @@ class NodeExecutor:
2562
2595
  iteration_outputs.append((iteration_index, False, None))
2563
2596
  else:
2564
2597
  # Execute all iterations concurrently
2598
+ # Get subflow_node reference for event updates (scoped outside the closure)
2599
+ subflow_node = end_loop_node if isinstance(end_loop_node, SubflowNodeGroup) else None
2600
+
2565
2601
  async def run_single_iteration(iteration_index: int) -> tuple[int, bool, dict[str, Any] | None]:
2566
2602
  try:
2567
2603
  flow_input = {start_node_name: parameter_values_per_iteration[iteration_index]}
@@ -2577,6 +2613,7 @@ class NodeExecutor:
2577
2613
  file_name=f"{file_name_prefix}_iteration_{iteration_index}",
2578
2614
  pickle_control_flow_result=True,
2579
2615
  flow_input=flow_input,
2616
+ node=subflow_node,
2580
2617
  )
2581
2618
  except Exception:
2582
2619
  logger.exception("Iteration %d failed for loop '%s'", iteration_index, end_loop_node.name)
@@ -2794,10 +2831,13 @@ class NodeExecutor:
2794
2831
  sanitized_loop_name = end_loop_node.name.replace(" ", "_")
2795
2832
  file_name_prefix = f"{sanitized_loop_name}_{library_name.replace(' ', '_')}_sequential_loop_flow"
2796
2833
 
2834
+ # Pass node for publishing progress events if it's a SubflowNodeGroup
2835
+ publish_node = end_loop_node if isinstance(end_loop_node, SubflowNodeGroup) else None
2797
2836
  published_workflow_filename, workflow_result = await self._publish_workflow_for_loop_execution(
2798
2837
  package_result=package_result,
2799
2838
  library_name=library_name,
2800
2839
  file_name=file_name_prefix,
2840
+ node=publish_node,
2801
2841
  )
2802
2842
 
2803
2843
  try:
@@ -2837,10 +2877,13 @@ class NodeExecutor:
2837
2877
  sanitized_loop_name = end_loop_node.name.replace(" ", "_")
2838
2878
  file_name_prefix = f"{sanitized_loop_name}_{library_name.replace(' ', '_')}_loop_flow"
2839
2879
 
2880
+ # Pass node for publishing progress events if it's a SubflowNodeGroup
2881
+ publish_node = end_loop_node if isinstance(end_loop_node, SubflowNodeGroup) else None
2840
2882
  published_workflow_filename, workflow_result = await self._publish_workflow_for_loop_execution(
2841
2883
  package_result=package_result,
2842
2884
  library_name=library_name,
2843
2885
  file_name=file_name_prefix,
2886
+ node=publish_node,
2844
2887
  )
2845
2888
 
2846
2889
  try:
@@ -2866,6 +2909,7 @@ class NodeExecutor:
2866
2909
  package_result: PackageNodesAsSerializedFlowResultSuccess,
2867
2910
  library_name: str,
2868
2911
  file_name: str,
2912
+ node: BaseNode | None = None,
2869
2913
  ) -> tuple[Path, Any]:
2870
2914
  """Save and publish workflow for loop execution via publisher.
2871
2915
 
@@ -2873,6 +2917,7 @@ class NodeExecutor:
2873
2917
  package_result: The packaged flow
2874
2918
  library_name: Name of the library to publish to
2875
2919
  file_name: Base file name for the workflow
2920
+ node: Optional node to receive publishing progress events
2876
2921
 
2877
2922
  Returns:
2878
2923
  Tuple of (published_workflow_filename, workflow_result)
@@ -2890,7 +2935,9 @@ class NodeExecutor:
2890
2935
  raise RuntimeError(msg) # noqa: TRY004 - This is a runtime failure, not a type validation error
2891
2936
 
2892
2937
  # Publish to the library
2893
- published_workflow_filename = await self._publish_library_workflow(workflow_result, library_name, file_name)
2938
+ published_workflow_filename = await self._publish_library_workflow(
2939
+ workflow_result, library_name, file_name, node=node
2940
+ )
2894
2941
 
2895
2942
  logger.info("Successfully published workflow to '%s'", published_workflow_filename)
2896
2943
 
@@ -3185,7 +3232,7 @@ class NodeExecutor:
3185
3232
  WorkflowRegistry.generate_new_workflow(str(workflow_path), result.metadata)
3186
3233
 
3187
3234
  delete_request = DeleteWorkflowRequest(name=workflow_name)
3188
- delete_result = GriptapeNodes.handle_request(delete_request)
3235
+ delete_result = await GriptapeNodes.ahandle_request(delete_request)
3189
3236
  if isinstance(delete_result, DeleteWorkflowResultFailure):
3190
3237
  logger.error(
3191
3238
  "Failed to delete workflow '%s'. Error: %s",
@@ -0,0 +1,21 @@
1
+ """Image metadata injection drivers.
2
+
3
+ Provides pluggable drivers for injecting workflow metadata into different image formats.
4
+ Drivers are automatically registered on module import.
5
+ """
6
+
7
+ from griptape_nodes.drivers.image_metadata.base_image_metadata_driver import BaseImageMetadataDriver
8
+ from griptape_nodes.drivers.image_metadata.exif_metadata_driver import ExifMetadataDriver
9
+ from griptape_nodes.drivers.image_metadata.image_metadata_driver_registry import ImageMetadataDriverRegistry
10
+ from griptape_nodes.drivers.image_metadata.png_metadata_driver import PngMetadataDriver
11
+
12
+ # Register core drivers on import
13
+ ImageMetadataDriverRegistry.register_driver(PngMetadataDriver())
14
+ ImageMetadataDriverRegistry.register_driver(ExifMetadataDriver())
15
+
16
+ __all__ = [
17
+ "BaseImageMetadataDriver",
18
+ "ExifMetadataDriver",
19
+ "ImageMetadataDriverRegistry",
20
+ "PngMetadataDriver",
21
+ ]
@@ -0,0 +1,63 @@
1
+ """Base class for image metadata injection and extraction drivers."""
2
+
3
+ from abc import ABC, abstractmethod
4
+
5
+ from PIL import Image
6
+
7
+
8
+ class BaseImageMetadataDriver(ABC):
9
+ """Base class for bidirectional image metadata drivers.
10
+
11
+ Each driver handles a specific metadata protocol (e.g., PNG text chunks, EXIF).
12
+ Drivers support both injection (writing) and extraction (reading) of metadata.
13
+
14
+ Extraction returns ALL available metadata for the format.
15
+ Injection writes only custom key-value pairs, preserving existing metadata.
16
+
17
+ Drivers are registered with ImageMetadataDriverRegistry and selected based on image format.
18
+ """
19
+
20
+ @abstractmethod
21
+ def get_supported_formats(self) -> list[str]:
22
+ """Return list of PIL format strings this driver supports.
23
+
24
+ Returns:
25
+ List of format strings (e.g., ["PNG"], ["JPEG", "TIFF", "MPO"])
26
+ """
27
+ ...
28
+
29
+ @abstractmethod
30
+ def inject_metadata(self, pil_image: Image.Image, metadata: dict[str, str]) -> bytes:
31
+ """Inject metadata into image and return modified image bytes.
32
+
33
+ Args:
34
+ pil_image: PIL Image object to inject metadata into
35
+ metadata: Dictionary of key-value string pairs to inject
36
+
37
+ Returns:
38
+ Image bytes with metadata injected
39
+
40
+ Raises:
41
+ Exception: On metadata injection failures
42
+ """
43
+ ...
44
+
45
+ @abstractmethod
46
+ def extract_metadata(self, pil_image: Image.Image) -> dict[str, str]:
47
+ """Extract ALL metadata from image.
48
+
49
+ Returns all available metadata for the format:
50
+ - PNG: All text chunks
51
+ - EXIF: Standard tags, GPS data, and custom UserComment field
52
+
53
+ The amount and type of metadata depends on the format and what's
54
+ present in the image. Custom metadata injected via inject_metadata()
55
+ is included alongside format-specific metadata.
56
+
57
+ Args:
58
+ pil_image: PIL Image object to extract metadata from
59
+
60
+ Returns:
61
+ Dictionary of metadata key-value pairs, empty dict if none found
62
+ """
63
+ ...
@@ -0,0 +1,218 @@
1
+ """EXIF metadata injection and extraction driver for JPEG, TIFF, and MPO formats."""
2
+
3
+ import json
4
+ import logging
5
+ from io import BytesIO
6
+ from typing import Any
7
+
8
+ from PIL import Image
9
+ from PIL.ExifTags import GPSTAGS, TAGS
10
+
11
+ from griptape_nodes.drivers.image_metadata.base_image_metadata_driver import BaseImageMetadataDriver
12
+
13
+ logger = logging.getLogger("griptape_nodes")
14
+
15
+ # EXIF tag IDs
16
+ EXIF_USERCOMMENT_TAG = 0x9286
17
+ EXIF_GPSINFO_TAG = 0x8825 # 34853
18
+
19
+ # GPS coordinate tuple length
20
+ GPS_COORD_TUPLE_LENGTH = 3
21
+
22
+
23
+ class ExifMetadataDriver(BaseImageMetadataDriver):
24
+ """Bidirectional driver for EXIF metadata.
25
+
26
+ Supports reading ALL EXIF metadata (standard tags, GPS data, custom UserComment)
27
+ and writing custom metadata to EXIF UserComment field as JSON.
28
+ Preserves all existing EXIF data when writing.
29
+ All EXIF-specific logic is encapsulated in this driver.
30
+ """
31
+
32
+ def get_supported_formats(self) -> list[str]:
33
+ """Return list of PIL format strings this driver supports.
34
+
35
+ Returns:
36
+ List containing "JPEG", "TIFF", "MPO"
37
+ """
38
+ return ["JPEG", "TIFF", "MPO", "WEBP"]
39
+
40
+ def inject_metadata(self, pil_image: Image.Image, metadata: dict[str, str]) -> bytes:
41
+ """Inject metadata into EXIF UserComment field as JSON.
42
+
43
+ Serializes metadata dictionary to JSON and stores in EXIF UserComment.
44
+ Preserves all existing EXIF tags.
45
+
46
+ Args:
47
+ pil_image: PIL Image to inject metadata into
48
+ metadata: Dictionary of key-value pairs to inject
49
+
50
+ Returns:
51
+ Image bytes with metadata injected
52
+
53
+ Raises:
54
+ Exception: On EXIF save errors
55
+ """
56
+ # Get existing EXIF data to preserve it
57
+ exif_data = pil_image.getexif()
58
+
59
+ # Serialize metadata to JSON for UserComment field
60
+ metadata_json = json.dumps(metadata, separators=(",", ":"))
61
+
62
+ # Set UserComment field
63
+ exif_data[EXIF_USERCOMMENT_TAG] = metadata_json
64
+
65
+ # Save with updated EXIF
66
+ output_buffer = BytesIO()
67
+ pil_image.save(output_buffer, format=pil_image.format, exif=exif_data)
68
+ return output_buffer.getvalue()
69
+
70
+ def extract_metadata(self, pil_image: Image.Image) -> dict[str, str]:
71
+ """Extract ALL EXIF metadata including standard tags, GPS data, and custom UserComment.
72
+
73
+ Returns combined metadata from all sources:
74
+ - Standard EXIF tags (Make, Model, DateTime, etc.)
75
+ - GPS metadata (prefixed with 'GPS_')
76
+ - Custom metadata from UserComment field (JSON parsed)
77
+
78
+ Args:
79
+ pil_image: PIL Image to extract metadata from
80
+
81
+ Returns:
82
+ Dictionary of all metadata key-value pairs, empty dict if no EXIF data
83
+ """
84
+ exif_data = pil_image.getexif()
85
+ if not exif_data:
86
+ return {}
87
+
88
+ metadata = {}
89
+
90
+ # Extract standard EXIF tags
91
+ self._extract_standard_tags(exif_data, metadata)
92
+
93
+ # Extract GPS metadata
94
+ self._extract_gps_metadata(exif_data, metadata)
95
+
96
+ # Extract custom UserComment metadata
97
+ self._extract_user_comment_metadata(exif_data, metadata)
98
+
99
+ return metadata
100
+
101
+ def _exif_value_to_string(self, value: Any) -> str:
102
+ """Convert EXIF value to readable string format.
103
+
104
+ Args:
105
+ value: EXIF value (can be bytes, tuple, list, int, etc.)
106
+
107
+ Returns:
108
+ String representation of the value
109
+ """
110
+ if isinstance(value, bytes):
111
+ try:
112
+ return value.decode("utf-8", errors="ignore").strip("\x00")
113
+ except Exception:
114
+ return str(value)
115
+ elif isinstance(value, (tuple, list)):
116
+ return ", ".join(str(v) for v in value)
117
+ return str(value)
118
+
119
+ def _format_gps_coordinate(self, coord_tuple: tuple) -> str:
120
+ """Format GPS coordinate tuple to decimal degrees string.
121
+
122
+ Args:
123
+ coord_tuple: Tuple of (degrees, minutes, seconds) as rational numbers
124
+
125
+ Returns:
126
+ Decimal degrees as string
127
+ """
128
+ if not coord_tuple or len(coord_tuple) != GPS_COORD_TUPLE_LENGTH:
129
+ return str(coord_tuple)
130
+
131
+ try:
132
+ degrees = float(coord_tuple[0])
133
+ minutes = float(coord_tuple[1])
134
+ seconds = float(coord_tuple[2])
135
+
136
+ decimal_degrees = degrees + (minutes / 60.0) + (seconds / 3600.0)
137
+ except Exception:
138
+ return self._exif_value_to_string(coord_tuple)
139
+ else:
140
+ return f"{decimal_degrees:.6f}"
141
+
142
+ def _extract_standard_tags(self, exif_data: Any, metadata: dict[str, str]) -> None:
143
+ """Extract standard EXIF tags.
144
+
145
+ Args:
146
+ exif_data: EXIF data object from PIL
147
+ metadata: Dictionary to populate with standard tags (modified in-place)
148
+ """
149
+ for tag_id, value in exif_data.items():
150
+ # Get tag name
151
+ tag_name = TAGS.get(tag_id, f"Tag_{tag_id}")
152
+
153
+ # Skip UserComment and GPSInfo, we'll handle them separately
154
+ if tag_id in (EXIF_USERCOMMENT_TAG, EXIF_GPSINFO_TAG):
155
+ continue
156
+
157
+ # Convert value to string
158
+ metadata[tag_name] = self._exif_value_to_string(value)
159
+
160
+ def _extract_gps_metadata(self, exif_data: Any, metadata: dict[str, str]) -> None:
161
+ """Extract GPS metadata from EXIF data.
162
+
163
+ Formats GPS coordinates as decimal degrees and prefixes all GPS tags with 'GPS_'.
164
+
165
+ Args:
166
+ exif_data: EXIF data object from PIL
167
+ metadata: Dictionary to populate with GPS metadata (modified in-place)
168
+ """
169
+ gps_ifd = exif_data.get_ifd(EXIF_GPSINFO_TAG)
170
+ if not gps_ifd:
171
+ return
172
+
173
+ for gps_tag_id, gps_value in gps_ifd.items():
174
+ gps_tag_name = GPSTAGS.get(gps_tag_id, f"GPSTag_{gps_tag_id}")
175
+
176
+ # Format GPS coordinate values specially
177
+ if gps_tag_id in (1, 2, 3, 4): # GPSLatitudeRef, GPSLatitude, GPSLongitudeRef, GPSLongitude
178
+ if gps_tag_id in (2, 4): # Latitude or Longitude tuple
179
+ metadata[f"GPS_{gps_tag_name}"] = self._format_gps_coordinate(gps_value)
180
+ else:
181
+ metadata[f"GPS_{gps_tag_name}"] = str(gps_value)
182
+ else:
183
+ metadata[f"GPS_{gps_tag_name}"] = self._exif_value_to_string(gps_value)
184
+
185
+ def _extract_user_comment_metadata(self, exif_data: Any, metadata: dict[str, str]) -> None:
186
+ """Extract custom metadata from EXIF UserComment field.
187
+
188
+ Args:
189
+ exif_data: EXIF data object from PIL
190
+ metadata: Dictionary to populate with custom metadata (modified in-place)
191
+ """
192
+ user_comment = exif_data.get(EXIF_USERCOMMENT_TAG)
193
+ if not user_comment:
194
+ return
195
+
196
+ # Parse JSON from UserComment field
197
+ try:
198
+ # Pillow handles the character code prefix, but might return bytes or string
199
+ if isinstance(user_comment, bytes):
200
+ comment_str = user_comment.decode("utf-8", errors="ignore").strip("\x00")
201
+ else:
202
+ comment_str = str(user_comment)
203
+
204
+ custom_metadata = json.loads(comment_str)
205
+ if not isinstance(custom_metadata, dict):
206
+ logger.debug("UserComment is not a dict, skipping custom metadata")
207
+ return
208
+
209
+ # Merge custom metadata directly into metadata dict
210
+ for key, value in custom_metadata.items():
211
+ metadata[key] = str(value)
212
+
213
+ if custom_metadata:
214
+ logger.debug("Merged %d custom metadata entries from EXIF UserComment", len(custom_metadata))
215
+ except (json.JSONDecodeError, ValueError) as e:
216
+ logger.debug("Could not parse UserComment as JSON: %s", e)
217
+ except Exception as e:
218
+ logger.debug("Unexpected error parsing UserComment: %s", e)