graphai-lib 0.0.5__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,19 +1,33 @@
1
- from typing import List, Dict, Any, Optional
2
- from graphai.nodes.base import _Node
1
+ from typing import List, Dict, Any, Optional, Protocol, Type
3
2
  from graphai.callback import Callback
4
3
  from graphai.utils import logger
5
4
 
6
5
 
6
+ class NodeProtocol(Protocol):
7
+ """Protocol defining the interface of a decorated node."""
8
+ name: str
9
+ is_start: bool
10
+ is_end: bool
11
+ is_router: bool
12
+ stream: bool
13
+
14
+ async def invoke(
15
+ self,
16
+ input: Dict[str, Any],
17
+ callback: Optional[Callback] = None,
18
+ state: Optional[Dict[str, Any]] = None
19
+ ) -> Dict[str, Any]: ...
20
+
21
+
7
22
  class Graph:
8
23
  def __init__(
9
24
  self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
10
25
  ):
11
- self.nodes: Dict[str, _Node] = {}
26
+ self.nodes: Dict[str, NodeProtocol] = {}
12
27
  self.edges: List[Any] = []
13
- self.start_node: Optional[_Node] = None
14
- self.end_nodes: List[_Node] = []
15
- self.Callback = Callback
16
- self.callback = None
28
+ self.start_node: Optional[NodeProtocol] = None
29
+ self.end_nodes: List[NodeProtocol] = []
30
+ self.Callback: Type[Callback] = Callback
17
31
  self.max_steps = max_steps
18
32
  self.state = initial_state or {}
19
33
 
@@ -34,7 +48,7 @@ class Graph:
34
48
  """Reset the graph state to an empty dict."""
35
49
  self.state = {}
36
50
 
37
- def add_node(self, node):
51
+ def add_node(self, node: NodeProtocol):
38
52
  if node.name in self.nodes:
39
53
  raise Exception(f"Node with name '{node.name}' already exists.")
40
54
  self.nodes[node.name] = node
@@ -49,7 +63,7 @@ class Graph:
49
63
  if node.is_end:
50
64
  self.end_nodes.append(node)
51
65
 
52
- def add_edge(self, source: _Node | str, destination: _Node | str):
66
+ def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str):
53
67
  """Adds an edge between two nodes that already exist in the graph.
54
68
 
55
69
  Args:
@@ -83,7 +97,7 @@ class Graph:
83
97
  self.edges.append(edge)
84
98
 
85
99
  def add_router(
86
- self, sources: list[_Node], router: _Node, destinations: List[_Node]
100
+ self, sources: list[NodeProtocol], router: NodeProtocol, destinations: List[NodeProtocol]
87
101
  ):
88
102
  if not router.is_router:
89
103
  raise TypeError("A router object must be passed to the router parameter.")
@@ -91,10 +105,10 @@ class Graph:
91
105
  for destination in destinations:
92
106
  self.add_edge(router, destination)
93
107
 
94
- def set_start_node(self, node: _Node):
108
+ def set_start_node(self, node: NodeProtocol):
95
109
  self.start_node = node
96
110
 
97
- def set_end_node(self, node: _Node):
111
+ def set_end_node(self, node: NodeProtocol):
98
112
  self.end_node = node
99
113
 
100
114
  def compile(self):
@@ -116,11 +130,15 @@ class Graph:
116
130
  f"Instead, got {type(output)} from '{output}'."
117
131
  )
118
132
 
119
- async def execute(self, input):
133
+ async def execute(self, input, callback: Callback | None = None):
120
134
  # TODO JB: may need to add init callback here to init the queue on every new execution
121
- if self.callback is None:
122
- self.callback = self.get_callback()
135
+ if callback is None:
136
+ callback = self.get_callback()
137
+
138
+ # Type assertion to tell the type checker that start_node is not None after compile()
139
+ assert self.start_node is not None, "Graph must be compiled before execution"
123
140
  current_node = self.start_node
141
+
124
142
  state = input
125
143
  # Don't reset the graph state if it was initialized with initial_state
126
144
  steps = 0
@@ -128,13 +146,13 @@ class Graph:
128
146
  # we invoke the node here
129
147
  if current_node.stream:
130
148
  # add callback tokens and param here if we are streaming
131
- await self.callback.start_node(node_name=current_node.name)
149
+ await callback.start_node(node_name=current_node.name)
132
150
  # Include graph's internal state in the node execution context
133
151
  output = await current_node.invoke(
134
- input=state, callback=self.callback, state=self.state
152
+ input=state, callback=callback, state=self.state
135
153
  )
136
154
  self._validate_output(output=output, node_name=current_node.name)
137
- await self.callback.end_node(node_name=current_node.name)
155
+ await callback.end_node(node_name=current_node.name)
138
156
  else:
139
157
  # Include graph's internal state in the node execution context
140
158
  output = await current_node.invoke(input=state, state=self.state)
@@ -159,16 +177,30 @@ class Graph:
159
177
  "by setting `max_steps` when initializing the Graph object."
160
178
  )
161
179
  # TODO JB: may need to add end callback here to close the queue for every execution
162
- if self.callback and "callback" in state:
163
- await self.callback.close()
180
+ if callback and "callback" in state:
181
+ await callback.close()
164
182
  del state["callback"]
165
183
  return state
166
184
 
167
185
  def get_callback(self):
168
- self.callback = self.Callback()
169
- return self.callback
186
+ """Get a new instance of the callback class.
187
+
188
+ :return: A new instance of the callback class.
189
+ :rtype: Callback
190
+ """
191
+ callback = self.Callback()
192
+ return callback
193
+
194
+ def set_callback(self, callback_class: Type[Callback]):
195
+ """Set the callback class that is returned by the `get_callback` method and used
196
+ as the default callback when no callback is passed to the `execute` method.
197
+
198
+ :param callback_class: The callback class to use as the default callback.
199
+ :type callback_class: Type[Callback]
200
+ """
201
+ self.Callback = callback_class
170
202
 
171
- def _get_node_by_name(self, node_name: str) -> _Node:
203
+ def _get_node_by_name(self, node_name: str) -> NodeProtocol:
172
204
  """Get a node by its name.
173
205
 
174
206
  Args:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.5
3
+ Version: 0.0.6
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "graphai-lib"
3
- version = "0.0.5"
3
+ version = "0.0.6"
4
4
  description = "Not an AI framework"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.14"
File without changes
File without changes