graphai-lib 0.0.5__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/PKG-INFO +1 -1
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai/graph.py +55 -23
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai_lib.egg-info/PKG-INFO +1 -1
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/pyproject.toml +1 -1
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/README.md +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai/__init__.py +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai/callback.py +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai/nodes/__init__.py +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai/nodes/base.py +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai/utils.py +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai_lib.egg-info/SOURCES.txt +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai_lib.egg-info/dependency_links.txt +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai_lib.egg-info/requires.txt +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/graphai_lib.egg-info/top_level.txt +0 -0
- {graphai_lib-0.0.5 → graphai_lib-0.0.6}/setup.cfg +0 -0
@@ -1,19 +1,33 @@
|
|
1
|
-
from typing import List, Dict, Any, Optional
|
2
|
-
from graphai.nodes.base import _Node
|
1
|
+
from typing import List, Dict, Any, Optional, Protocol, Type
|
3
2
|
from graphai.callback import Callback
|
4
3
|
from graphai.utils import logger
|
5
4
|
|
6
5
|
|
6
|
+
class NodeProtocol(Protocol):
|
7
|
+
"""Protocol defining the interface of a decorated node."""
|
8
|
+
name: str
|
9
|
+
is_start: bool
|
10
|
+
is_end: bool
|
11
|
+
is_router: bool
|
12
|
+
stream: bool
|
13
|
+
|
14
|
+
async def invoke(
|
15
|
+
self,
|
16
|
+
input: Dict[str, Any],
|
17
|
+
callback: Optional[Callback] = None,
|
18
|
+
state: Optional[Dict[str, Any]] = None
|
19
|
+
) -> Dict[str, Any]: ...
|
20
|
+
|
21
|
+
|
7
22
|
class Graph:
|
8
23
|
def __init__(
|
9
24
|
self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
|
10
25
|
):
|
11
|
-
self.nodes: Dict[str,
|
26
|
+
self.nodes: Dict[str, NodeProtocol] = {}
|
12
27
|
self.edges: List[Any] = []
|
13
|
-
self.start_node: Optional[
|
14
|
-
self.end_nodes: List[
|
15
|
-
self.Callback = Callback
|
16
|
-
self.callback = None
|
28
|
+
self.start_node: Optional[NodeProtocol] = None
|
29
|
+
self.end_nodes: List[NodeProtocol] = []
|
30
|
+
self.Callback: Type[Callback] = Callback
|
17
31
|
self.max_steps = max_steps
|
18
32
|
self.state = initial_state or {}
|
19
33
|
|
@@ -34,7 +48,7 @@ class Graph:
|
|
34
48
|
"""Reset the graph state to an empty dict."""
|
35
49
|
self.state = {}
|
36
50
|
|
37
|
-
def add_node(self, node):
|
51
|
+
def add_node(self, node: NodeProtocol):
|
38
52
|
if node.name in self.nodes:
|
39
53
|
raise Exception(f"Node with name '{node.name}' already exists.")
|
40
54
|
self.nodes[node.name] = node
|
@@ -49,7 +63,7 @@ class Graph:
|
|
49
63
|
if node.is_end:
|
50
64
|
self.end_nodes.append(node)
|
51
65
|
|
52
|
-
def add_edge(self, source:
|
66
|
+
def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str):
|
53
67
|
"""Adds an edge between two nodes that already exist in the graph.
|
54
68
|
|
55
69
|
Args:
|
@@ -83,7 +97,7 @@ class Graph:
|
|
83
97
|
self.edges.append(edge)
|
84
98
|
|
85
99
|
def add_router(
|
86
|
-
self, sources: list[
|
100
|
+
self, sources: list[NodeProtocol], router: NodeProtocol, destinations: List[NodeProtocol]
|
87
101
|
):
|
88
102
|
if not router.is_router:
|
89
103
|
raise TypeError("A router object must be passed to the router parameter.")
|
@@ -91,10 +105,10 @@ class Graph:
|
|
91
105
|
for destination in destinations:
|
92
106
|
self.add_edge(router, destination)
|
93
107
|
|
94
|
-
def set_start_node(self, node:
|
108
|
+
def set_start_node(self, node: NodeProtocol):
|
95
109
|
self.start_node = node
|
96
110
|
|
97
|
-
def set_end_node(self, node:
|
111
|
+
def set_end_node(self, node: NodeProtocol):
|
98
112
|
self.end_node = node
|
99
113
|
|
100
114
|
def compile(self):
|
@@ -116,11 +130,15 @@ class Graph:
|
|
116
130
|
f"Instead, got {type(output)} from '{output}'."
|
117
131
|
)
|
118
132
|
|
119
|
-
async def execute(self, input):
|
133
|
+
async def execute(self, input, callback: Callback | None = None):
|
120
134
|
# TODO JB: may need to add init callback here to init the queue on every new execution
|
121
|
-
if
|
122
|
-
|
135
|
+
if callback is None:
|
136
|
+
callback = self.get_callback()
|
137
|
+
|
138
|
+
# Type assertion to tell the type checker that start_node is not None after compile()
|
139
|
+
assert self.start_node is not None, "Graph must be compiled before execution"
|
123
140
|
current_node = self.start_node
|
141
|
+
|
124
142
|
state = input
|
125
143
|
# Don't reset the graph state if it was initialized with initial_state
|
126
144
|
steps = 0
|
@@ -128,13 +146,13 @@ class Graph:
|
|
128
146
|
# we invoke the node here
|
129
147
|
if current_node.stream:
|
130
148
|
# add callback tokens and param here if we are streaming
|
131
|
-
await
|
149
|
+
await callback.start_node(node_name=current_node.name)
|
132
150
|
# Include graph's internal state in the node execution context
|
133
151
|
output = await current_node.invoke(
|
134
|
-
input=state, callback=
|
152
|
+
input=state, callback=callback, state=self.state
|
135
153
|
)
|
136
154
|
self._validate_output(output=output, node_name=current_node.name)
|
137
|
-
await
|
155
|
+
await callback.end_node(node_name=current_node.name)
|
138
156
|
else:
|
139
157
|
# Include graph's internal state in the node execution context
|
140
158
|
output = await current_node.invoke(input=state, state=self.state)
|
@@ -159,16 +177,30 @@ class Graph:
|
|
159
177
|
"by setting `max_steps` when initializing the Graph object."
|
160
178
|
)
|
161
179
|
# TODO JB: may need to add end callback here to close the queue for every execution
|
162
|
-
if
|
163
|
-
await
|
180
|
+
if callback and "callback" in state:
|
181
|
+
await callback.close()
|
164
182
|
del state["callback"]
|
165
183
|
return state
|
166
184
|
|
167
185
|
def get_callback(self):
|
168
|
-
|
169
|
-
|
186
|
+
"""Get a new instance of the callback class.
|
187
|
+
|
188
|
+
:return: A new instance of the callback class.
|
189
|
+
:rtype: Callback
|
190
|
+
"""
|
191
|
+
callback = self.Callback()
|
192
|
+
return callback
|
193
|
+
|
194
|
+
def set_callback(self, callback_class: Type[Callback]):
|
195
|
+
"""Set the callback class that is returned by the `get_callback` method and used
|
196
|
+
as the default callback when no callback is passed to the `execute` method.
|
197
|
+
|
198
|
+
:param callback_class: The callback class to use as the default callback.
|
199
|
+
:type callback_class: Type[Callback]
|
200
|
+
"""
|
201
|
+
self.Callback = callback_class
|
170
202
|
|
171
|
-
def _get_node_by_name(self, node_name: str) ->
|
203
|
+
def _get_node_by_name(self, node_name: str) -> NodeProtocol:
|
172
204
|
"""Get a node by its name.
|
173
205
|
|
174
206
|
Args:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|