graphai-lib 0.0.8__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/PKG-INFO +1 -1
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai/callback.py +142 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai/graph.py +243 -39
- graphai_lib-0.0.9/graphai/py.typed +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai/utils.py +127 -21
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai_lib.egg-info/PKG-INFO +1 -1
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai_lib.egg-info/SOURCES.txt +1 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/pyproject.toml +4 -1
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/README.md +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai/__init__.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai/nodes/__init__.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai/nodes/base.py +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai_lib.egg-info/dependency_links.txt +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai_lib.egg-info/requires.txt +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/graphai_lib.egg-info/top_level.txt +0 -0
- {graphai_lib-0.0.8 → graphai_lib-0.0.9}/setup.cfg +0 -0
@@ -1,13 +1,53 @@
|
|
1
1
|
import asyncio
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from enum import Enum
|
2
4
|
from pydantic import Field
|
3
5
|
from typing import Any
|
4
6
|
from collections.abc import AsyncIterator
|
7
|
+
import warnings
|
5
8
|
|
6
9
|
|
7
10
|
# NOTE(review): module-level flag; its consumers are not visible in this
# section of callback.py — confirm where and how it is read.
log_stream = True
|
8
11
|
|
9
12
|
|
13
|
+
class StrEnum(str, Enum):
    """Enum whose members are also ``str`` instances.

    Mixing in ``str`` makes members compare equal to their raw values
    (e.g. ``GraphEventType.END == "end"``), so consumers of the event stream
    can match on plain strings. This mirrors Python 3.11's ``enum.StrEnum``.
    ``str(member)`` returns the member's value rather than ``Class.MEMBER``.
    """

    def __str__(self) -> str:
        return str(self.value)
|
16
|
+
|
17
|
+
class GraphEventType(StrEnum):
    """The kinds of :class:`GraphEvent` that can appear on an event stream.

    ``START_NODE``/``END_NODE`` bracket a node's execution, ``CALLBACK``
    carries a token pushed by user code, and ``END`` marks that the stream
    has been closed. (``START`` is defined here but is not emitted by the
    event-callback code in this module — NOTE(review): confirm its producer.)
    """

    START = "start"
    END = "end"
    START_NODE = "start_node"
    END_NODE = "end_node"
    CALLBACK = "callback"
|
23
|
+
|
24
|
+
@dataclass
class GraphEvent:
    """A structured event emitted on an ``EventCallback`` stream.

    Events are produced at graph lifecycle points (node start, node end,
    stream end) and for tokens pushed by user code through the callback.

    :param type: The kind of event: ``start``, ``end``, ``start_node``,
        ``end_node`` or ``callback``.
    :type type: GraphEventType
    :param identifier: The identifier of the event. This is typically set by
        the callback handler and can be used to distinguish between different
        event streams, e.g. a conversation/session ID.
    :type identifier: str
    :param token: The event payload. For ``callback`` events this is the
        streamed token (such as LLM output); for ``start_node``/``end_node``
        events it is the node name. Defaults to ``None``.
    :type token: str | None
    :param params: Parameters associated with the event, such as tool call
        parameters or event metadata. Defaults to ``None``.
    :type params: dict[str, Any] | None
    """

    type: GraphEventType
    identifier: str
    token: str | None = None
    params: dict[str, Any] | None = None
|
45
|
+
|
46
|
+
|
10
47
|
class Callback:
|
48
|
+
"""The original callback handler class. Outputs a stream of structured text
|
49
|
+
tokens. It is recommended to use the newer `EventCallback` handler instead.
|
50
|
+
"""
|
11
51
|
identifier: str = Field(
|
12
52
|
default="graphai",
|
13
53
|
description=(
|
@@ -67,6 +107,11 @@ class Callback:
|
|
67
107
|
special_token_format: str = "<{identifier}:{token}:{params}>",
|
68
108
|
token_format: str = "{token}",
|
69
109
|
):
|
110
|
+
warnings.warn(
|
111
|
+
"The `Callback` class is deprecated and will be removed in " +
|
112
|
+
"v0.1.0. Use the `EventCallback` class instead.",
|
113
|
+
DeprecationWarning
|
114
|
+
)
|
70
115
|
self.identifier = identifier
|
71
116
|
self.special_token_format = special_token_format
|
72
117
|
self.token_format = token_format
|
@@ -198,3 +243,100 @@ class Callback:
|
|
198
243
|
return self.special_token_format.format(
|
199
244
|
identifier=identifier, token=name, params=params_str
|
200
245
|
)
|
246
|
+
|
247
|
+
|
248
|
+
class EventCallback(Callback):
    """Callback handler that streams structured :class:`GraphEvent` objects.

    Where the legacy :class:`Callback` emits formatted text tokens, this
    handler puts ``GraphEvent`` instances on ``self.queue``:

    - ``START_NODE``/``END_NODE`` around each active node,
    - ``CALLBACK`` for every token pushed by user code,
    - a final ``END`` when the stream is closed.

    This is the recommended callback handler.

    :param identifier: Identifier stamped on every emitted event.
    :type identifier: str
    :param special_token_format: Deprecated. Accepted only for compatibility
        with :class:`Callback`; this class's event methods do not use it.
    :type special_token_format: str | None
    :param token_format: Deprecated. Accepted only for compatibility with
        :class:`Callback`.
    :type token_format: str | None
    """

    def __init__(
        self,
        identifier: str = "graphai",
        special_token_format: str | None = None,
        token_format: str | None = None,
    ):
        # Only warn when a deprecated parameter was actually supplied; both
        # default to None, so an unconfigured handler stays silent.
        if special_token_format is not None or token_format is not None:
            warnings.warn(
                "The `special_token_format` and `token_format` parameters are "
                "deprecated and will be removed in v0.1.0.",
                DeprecationWarning,
                stacklevel=2,
            )
        if special_token_format is None:
            special_token_format = "<{identifier}:{token}:{params}>"
        if token_format is None:
            token_format = "{token}"
        # Callback.__init__ warns "use `EventCallback` instead", which is
        # meaningless for EventCallback itself. Suppress only that message.
        # Note: catch_warnings mutates the process-wide filter list.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The `Callback` class is deprecated",
                category=DeprecationWarning,
            )
            super().__init__(identifier, special_token_format, token_format)
        # NOTE(review): not populated by any method of this class.
        self.events: list[GraphEvent] = []

    def _put_callback_event(self, token: str, node_name: str | None) -> None:
        """Validate the stream and enqueue a ``CALLBACK`` event for ``token``.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot add tokens to a closed stream")
        self._check_node_name(node_name=node_name)
        # otherwise we just assume node is correct and send token
        self.queue.put_nowait(
            GraphEvent(
                type=GraphEventType.CALLBACK,
                identifier=self.identifier,
                token=token,
                params=None,
            )
        )

    def __call__(self, token: str, node_name: str | None = None):
        """Emit ``token`` as a ``CALLBACK`` event (synchronous entry point)."""
        self._put_callback_event(token, node_name)

    async def acall(self, token: str, node_name: str | None = None):
        """Emit ``token`` as a ``CALLBACK`` event (async entry point)."""
        # TODO JB: do we need to have `node_name` param?
        self._put_callback_event(token, node_name)

    async def aiter(self) -> AsyncIterator[GraphEvent]:  # type: ignore[override]
        """Yield events from the stream queue until the ``END`` event.

        Iteration also stops if the stream is already done and the queue is
        empty, or if the consuming task is cancelled. The stream is marked
        done afterwards.
        """
        while True:  # Keep going until we see the END event
            try:
                if self._done and self.queue.empty():
                    break
                event = await self.queue.get()
                yield event
                self.queue.task_done()
                if event.type == GraphEventType.END:
                    break
            except asyncio.CancelledError:
                break
        self._done = True  # Mark as done after processing all events

    async def start_node(self, node_name: str, active: bool = True):
        """Start ``node_name`` and, if ``active``, emit a ``START_NODE`` event.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot start node on a closed stream")
        self.current_node_name = node_name
        if self.first_token:
            self.first_token = False
        self.active = active
        if self.active:
            self.queue.put_nowait(
                GraphEvent(
                    type=GraphEventType.START_NODE,
                    identifier=self.identifier,
                    token=self.current_node_name,
                    params=None,
                )
            )

    async def end_node(self, node_name: str):
        """Emit an ``END_NODE`` event for the current node if it is active.

        The event carries ``self.current_node_name`` (set by ``start_node``);
        ``node_name`` is accepted for interface compatibility only.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot end node on a closed stream")
        if self.active:
            self.queue.put_nowait(
                GraphEvent(
                    type=GraphEventType.END_NODE,
                    identifier=self.identifier,
                    token=self.current_node_name,
                    params=None,
                )
            )

    async def close(self):
        """Close the stream and prevent further events from being added.

        Enqueues an ``END`` event and sets the done flag. Calling this on an
        already-closed stream is a no-op.
        """
        if self._done:
            return
        end_event = GraphEvent(type=GraphEventType.END, identifier=self.identifier)
        self._done = True  # Set done before putting the end event
        self.queue.put_nowait(end_event)
        # Don't wait for queue.join() as it can cause deadlock;
        # the stream will close when aiter processes the END event.

    async def _build_special_token(
        self, name: str, params: dict[str, Any] | None = None
    ):
        """Not supported: this handler emits ``GraphEvent`` objects, not text."""
        raise NotImplementedError("This method is not implemented for the `EventCallback` class.")
|
342
|
+
|
@@ -1,8 +1,24 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
|
+
import asyncio
|
3
|
+
from typing import Any, Iterable, Protocol
|
4
|
+
from graphlib import TopologicalSorter, CycleError
|
2
5
|
from graphai.callback import Callback
|
3
6
|
from graphai.utils import logger
|
4
7
|
|
5
8
|
|
9
|
+
# to fix mypy error
|
10
|
+
class _HasName(Protocol):
    """Structural type: any object exposing a ``name`` string attribute.

    Used to narrow the ``start_node`` attribute for mypy in ``Graph.compile``.
    """

    name: str
|
12
|
+
|
13
|
+
|
14
|
+
class GraphError(Exception):
    """Base class for errors raised by the graph module."""
|
16
|
+
|
17
|
+
|
18
|
+
class GraphCompileError(GraphError):
    """Raised by ``Graph.compile`` when the graph fails validation."""
|
20
|
+
|
21
|
+
|
6
22
|
class NodeProtocol(Protocol):
|
7
23
|
"""Protocol defining the interface of a decorated node."""
|
8
24
|
|
@@ -20,6 +36,26 @@ class NodeProtocol(Protocol):
|
|
20
36
|
) -> dict[str, Any]: ...
|
21
37
|
|
22
38
|
|
39
|
+
def _name_of(x: Any) -> str | None:
|
40
|
+
"""Return the node name if x is a str or has .name, else None."""
|
41
|
+
if x is None:
|
42
|
+
return None
|
43
|
+
if isinstance(x, str):
|
44
|
+
return x
|
45
|
+
name = getattr(x, "name", None)
|
46
|
+
return name if isinstance(name, str) else None
|
47
|
+
|
48
|
+
|
49
|
+
def _require_name(x: Any, kind: str) -> str:
    """Resolve ``x`` to a node name, failing compilation when it has none.

    :param x: A node name, or an object exposing a ``str`` ``name`` attribute.
    :param kind: Edge-end label used in the error message (e.g. ``"source"``).
    :return: The resolved node name.
    :raises GraphCompileError: If ``x`` cannot be resolved to a name.
    """
    resolved = _name_of(x)
    if resolved is not None:
        return resolved
    raise GraphCompileError(
        f"Edge {kind} must be a node name (str) or object with .name"
    )
|
57
|
+
|
58
|
+
|
23
59
|
class Graph:
|
24
60
|
def __init__(
|
25
61
|
self, max_steps: int = 10, initial_state: dict[str, Any] | None = None
|
@@ -28,7 +64,7 @@ class Graph:
|
|
28
64
|
self.edges: list[Any] = []
|
29
65
|
self.start_node: NodeProtocol | None = None
|
30
66
|
self.end_nodes: list[NodeProtocol] = []
|
31
|
-
self.Callback:
|
67
|
+
self.Callback: type[Callback] = Callback
|
32
68
|
self.max_steps = max_steps
|
33
69
|
self.state = initial_state or {}
|
34
70
|
|
@@ -37,22 +73,22 @@ class Graph:
|
|
37
73
|
"""Get the current graph state."""
|
38
74
|
return self.state
|
39
75
|
|
40
|
-
def set_state(self, state: dict[str, Any]) ->
|
76
|
+
def set_state(self, state: dict[str, Any]) -> Graph:
    """Replace the graph state with ``state`` (stored as-is, not copied).

    :param state: The new state dict.
    :return: ``self``, for chaining.
    """
    self.state = state
    return self
|
44
80
|
|
45
|
-
def update_state(self, values: dict[str, Any]) ->
|
81
|
+
def update_state(self, values: dict[str, Any]) -> Graph:
    """Merge ``values`` into the current graph state in place.

    :param values: Keys/values to add or overwrite in the state.
    :return: ``self``, for chaining.
    """
    current = self.state
    current.update(values)
    return self
|
49
85
|
|
50
|
-
def reset_state(self) ->
|
86
|
+
def reset_state(self) -> Graph:
    """Discard the graph state, replacing it with a new empty dict.

    :return: ``self``, for chaining.
    """
    self.state = dict()
    return self
|
54
90
|
|
55
|
-
def add_node(self, node: NodeProtocol) ->
|
91
|
+
def add_node(self, node: NodeProtocol) -> Graph:
|
56
92
|
if node.name in self.nodes:
|
57
93
|
raise Exception(f"Node with name '{node.name}' already exists.")
|
58
94
|
self.nodes[node.name] = node
|
@@ -68,7 +104,9 @@ class Graph:
|
|
68
104
|
self.end_nodes.append(node)
|
69
105
|
return self
|
70
106
|
|
71
|
-
def add_edge(
|
107
|
+
def add_edge(
|
108
|
+
self, source: NodeProtocol | str, destination: NodeProtocol | str
|
109
|
+
) -> Graph:
|
72
110
|
"""Adds an edge between two nodes that already exist in the graph.
|
73
111
|
|
74
112
|
Args:
|
@@ -89,9 +127,7 @@ class Graph:
|
|
89
127
|
else:
|
90
128
|
source_name = str(source)
|
91
129
|
if source_node is None:
|
92
|
-
raise ValueError(
|
93
|
-
f"Node with name '{source_name}' not found."
|
94
|
-
)
|
130
|
+
raise ValueError(f"Node with name '{source_name}' not found.")
|
95
131
|
# get destination node from graph
|
96
132
|
destination_name: str
|
97
133
|
if isinstance(destination, str):
|
@@ -105,9 +141,7 @@ class Graph:
|
|
105
141
|
else:
|
106
142
|
destination_name = str(destination)
|
107
143
|
if destination_node is None:
|
108
|
-
raise ValueError(
|
109
|
-
f"Node with name '{destination_name}' not found."
|
110
|
-
)
|
144
|
+
raise ValueError(f"Node with name '{destination_name}' not found.")
|
111
145
|
edge = Edge(source_node, destination_node)
|
112
146
|
self.edges.append(edge)
|
113
147
|
return self
|
@@ -117,7 +151,7 @@ class Graph:
|
|
117
151
|
sources: list[NodeProtocol],
|
118
152
|
router: NodeProtocol,
|
119
153
|
destinations: list[NodeProtocol],
|
120
|
-
) ->
|
154
|
+
) -> Graph:
|
121
155
|
if not router.is_router:
|
122
156
|
raise TypeError("A router object must be passed to the router parameter.")
|
123
157
|
[self.add_edge(source, router) for source in sources]
|
@@ -125,26 +159,162 @@ class Graph:
|
|
125
159
|
self.add_edge(router, destination)
|
126
160
|
return self
|
127
161
|
|
128
|
-
def set_start_node(self, node: NodeProtocol) ->
|
162
|
+
def set_start_node(self, node: NodeProtocol) -> Graph:
    """Designate ``node`` as the graph's start node.

    ``compile`` uses this node as the start, instead of looking for nodes
    flagged ``is_start``/``start``.

    :param node: The node to start execution from.
    :return: ``self``, for chaining.
    """
    self.start_node = node
    return self
|
131
165
|
|
132
|
-
def set_end_node(self, node: NodeProtocol) ->
|
166
|
+
def set_end_node(self, node: NodeProtocol) -> Graph:
    """Record ``node`` as ``self.end_node``.

    NOTE(review): this sets ``end_node``. ``__init__`` initialises an
    ``end_nodes`` list instead, and ``compile`` checks node ``is_end``/``end``
    flags. Confirm which of these execution actually relies on.

    :param node: The node to record as the end node.
    :return: ``self``, for chaining.
    """
    self.end_node = node
    return self
|
135
169
|
|
136
|
-
def compile(self) ->
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
170
|
+
def compile(self, *, strict: bool = False) -> Graph:
    """Validate the graph and return it.

    Checks, in order:

    - at least one node has been added
    - exactly one start node: ``Graph.start_node`` if set, otherwise the
      single node flagged ``is_start``/``start``; it must be one of the
      graph's nodes
    - at least one node flagged ``is_end``/``end``
    - every edge references nodes that exist in the graph
    - every node is reachable from the start node
    - (only when ``strict=True``) the graph contains no cycles

    Supported shapes for ``self.edges``:

    - a ``dict`` mapping a source to one destination or an iterable of
      destinations;
    - an iterable of edge records: ``(src, dst)`` / ``(src, Iterable[dst])``
      pairs, mappings with ``source``/``destination`` (or ``src``/``dst``)
      keys, or objects with those attributes.

    Endpoints may be node names or objects with a ``str`` ``.name``.

    :param strict: Additionally reject graphs that contain a cycle.
    :type strict: bool
    :return: ``self``, so calls can be chained.
    :rtype: Graph
    :raises GraphCompileError: If any check fails.
    """
    nodes = getattr(self, "nodes", None)
    if not isinstance(nodes, dict) or not nodes:
        raise GraphCompileError("No nodes have been added to the graph")

    # --- start node -------------------------------------------------------
    start_name: str | None = None
    # Bind and narrow the attribute for mypy
    start_node: _HasName | None = getattr(self, "start_node", None)
    if start_node is not None:
        start_name = start_node.name
    else:
        starts = [
            name
            for name, n in nodes.items()
            if getattr(n, "is_start", False) or getattr(n, "start", False)
        ]
        if len(starts) > 1:
            raise GraphCompileError(f"Multiple start nodes defined: {starts}")
        if starts:
            start_name = starts[0]

    if not start_name:
        raise GraphCompileError("No start node defined")
    if start_name not in nodes:
        # Without this, the reachability pass below would report every node
        # as unreachable, hiding the actual problem.
        raise GraphCompileError(
            f"Start node '{start_name}' has not been added to the graph"
        )

    # --- at least one end node --------------------------------------------
    if not any(
        getattr(n, "is_end", False) or getattr(n, "end", False)
        for n in nodes.values()
    ):
        raise GraphCompileError("No end node defined")

    # --- normalize edges into adjacency {src: {dst, ...}} -----------------
    adj: dict[str, set[str]] = {name: set() for name in nodes}

    def _add_edge(src: str, dst: str) -> None:
        # Both endpoints must already be resolved names of known nodes.
        if src not in nodes:
            raise GraphCompileError(f"Edge references unknown source node: {src}")
        if dst not in nodes:
            raise GraphCompileError(
                f"Edge from {src} references unknown node(s): ['{dst}']"
            )
        adj[src].add(dst)

    def _add_destinations(src: str, rhs: Any, error: str) -> None:
        # ``rhs`` is either one destination (name or node-like object) or an
        # iterable of them. Strings are treated as a single name, never
        # iterated character by character.
        if _name_of(rhs) is not None:
            _add_edge(src, _require_name(rhs, "destination"))
            return
        try:
            destinations = iter(rhs)
        except TypeError:
            raise GraphCompileError(error) from None
        for d in destinations:
            # Always pass the resolved *name*, never the raw node object.
            _add_edge(src, _require_name(d, "destination"))

    raw_edges = getattr(self, "edges", None)
    if isinstance(raw_edges, dict):
        for raw_src, dsts in raw_edges.items():
            _add_destinations(
                _require_name(raw_src, "source"),
                dsts,
                "Edge map values must be a destination or an iterable of destinations",
            )
    elif raw_edges is not None:
        # generic iterable of "edge records"
        try:
            records = iter(raw_edges)
        except TypeError:
            raise GraphCompileError("Internal edge map has unsupported type") from None

        for item in records:
            if isinstance(item, (tuple, list)) and len(item) == 2:
                # (src, dst) OR (src, Iterable[dst])
                raw_src, rhs = item
                _add_destinations(
                    _require_name(raw_src, "source"),
                    rhs,
                    "Edge tuple second item must be a destination or an iterable of destinations",
                )
            elif isinstance(item, dict):
                # {"source": ..., "destination": ...} or {"src": ..., "dst": ...}
                src = _require_name(item.get("source", item.get("src")), "source")
                dst = _require_name(
                    item.get("destination", item.get("dst")), "destination"
                )
                _add_edge(src, dst)
            elif hasattr(item, "source") or hasattr(item, "src"):
                # Object with attributes .source/.destination (or .src/.dst)
                src = _require_name(
                    getattr(item, "source", getattr(item, "src", None)), "source"
                )
                dst = _require_name(
                    getattr(item, "destination", getattr(item, "dst", None)),
                    "destination",
                )
                _add_edge(src, dst)
            else:
                raise GraphCompileError(
                    "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
                    "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
                )

    # --- reachability from start (iterative DFS) ---------------------------
    seen: set[str] = set()
    stack = [start_name]
    while stack:
        cur = stack.pop()
        if cur in seen:
            continue
        seen.add(cur)
        stack.extend(adj[cur] - seen)

    unreachable = sorted(nodes.keys() - seen)
    if unreachable:
        raise GraphCompileError(f"Unreachable nodes: {unreachable}")

    # --- optional cycle detection (strict mode) ----------------------------
    if strict:
        preds: dict[str, set[str]] = {n: set() for n in nodes}
        for s, ds in adj.items():
            for d in ds:
                preds[d].add(s)
        try:
            # prepare() raises CycleError without computing a full ordering.
            TopologicalSorter(preds).prepare()
        except CycleError as e:
            raise GraphCompileError("cycle detected in graph (strict mode)") from e

    return self
|
148
318
|
|
149
319
|
def _validate_output(self, output: dict[str, Any], node_name: str):
|
150
320
|
if not isinstance(output, dict):
|
@@ -205,6 +375,32 @@ class Graph:
|
|
205
375
|
del state["callback"]
|
206
376
|
return state
|
207
377
|
|
378
|
+
async def execute_many(
    self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
) -> list[Any]:
    """Execute the graph on many inputs concurrently.

    Each input dict is passed to ``self.execute(input=...)``. At most
    ``concurrency`` executions are in flight at any time. If any execution
    raises, the remaining executions are cancelled and the exception is
    propagated.

    NOTE(review): all executions share this ``Graph`` instance; confirm that
    ``execute`` does not mutate shared instance state across concurrent runs.

    :param inputs: An iterable of input dicts to feed into the graph.
    :type inputs: Iterable[dict]
    :param concurrency: Maximum number of graph executions to run at once.
        Must be at least 1.
    :type concurrency: int
    :returns: The list of results in the same order as ``inputs``.
    :rtype: list[Any]
    :raises ValueError: If ``concurrency`` is less than 1. A zero-permit
        semaphore would otherwise block forever.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    sem = asyncio.Semaphore(concurrency)

    async def _run_one(inp: dict[str, Any]) -> Any:
        async with sem:
            return await self.execute(input=inp)

    tasks = [asyncio.create_task(_run_one(i)) for i in inputs]
    try:
        return await asyncio.gather(*tasks)
    except BaseException:
        # gather() does not cancel siblings on failure; do it explicitly so
        # no execution keeps running after the caller has seen the error.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
|
403
|
+
|
208
404
|
def get_callback(self):
|
209
405
|
"""Get a new instance of the callback class.
|
210
406
|
|
@@ -219,7 +415,7 @@ class Graph:
|
|
219
415
|
as the default callback when no callback is passed to the `execute` method.
|
220
416
|
|
221
417
|
:param callback_class: The callback class to use as the default callback.
|
222
|
-
:type callback_class:
|
418
|
+
:type callback_class: type[Callback]
|
223
419
|
"""
|
224
420
|
self.Callback = callback_class
|
225
421
|
return self
|
@@ -249,20 +445,24 @@ class Graph:
|
|
249
445
|
f"No outgoing edge found for current node '{current_node.name}'."
|
250
446
|
)
|
251
447
|
|
252
|
-
def visualize(self):
|
448
|
+
def visualize(self, *, save_path: str | None = None):
|
449
|
+
"""Render the current graph. If matplotlib is not installed,
|
450
|
+
raise a helpful error telling users to install the viz extra.
|
451
|
+
Optionally save to a file via `save_path`.
|
452
|
+
"""
|
253
453
|
try:
|
254
|
-
import
|
255
|
-
except ImportError:
|
454
|
+
import matplotlib.pyplot as plt
|
455
|
+
except ImportError as e:
|
256
456
|
raise ImportError(
|
257
|
-
"
|
258
|
-
)
|
457
|
+
"Graph visualization requires matplotlib. Install it with: `pip install matplotlib`"
|
458
|
+
) from e
|
259
459
|
|
260
460
|
try:
|
261
|
-
import
|
262
|
-
except ImportError:
|
461
|
+
import networkx as nx
|
462
|
+
except ImportError as e:
|
263
463
|
raise ImportError(
|
264
|
-
"
|
265
|
-
)
|
464
|
+
"NetworkX is required for visualization. Please install it with `pip install networkx`."
|
465
|
+
) from e
|
266
466
|
|
267
467
|
G: Any = nx.DiGraph()
|
268
468
|
|
@@ -328,8 +528,12 @@ class Graph:
|
|
328
528
|
arrowsize=20,
|
329
529
|
)
|
330
530
|
|
331
|
-
|
332
|
-
|
531
|
+
if save_path:
|
532
|
+
plt.savefig(save_path, bbox_inches="tight")
|
533
|
+
else:
|
534
|
+
plt.axis("off")
|
535
|
+
plt.show()
|
536
|
+
plt.close()
|
333
537
|
|
334
538
|
|
335
539
|
class Edge:
|
File without changes
|
@@ -1,9 +1,18 @@
|
|
1
|
+
from enum import Enum
|
1
2
|
import inspect
|
2
3
|
import os
|
3
|
-
|
4
|
+
import sys
|
5
|
+
from typing import Any, Callable, Union, get_args, get_origin
|
4
6
|
from pydantic import BaseModel, Field
|
7
|
+
from pydantic_core import PydanticUndefined
|
5
8
|
import logging
|
6
|
-
|
9
|
+
|
10
|
+
|
11
|
+
# we support python 3.10 so we define our own StrEnum (introduced in 3.11)
|
12
|
+
class StrEnum(str, Enum):
    """Minimal stand-in for ``enum.StrEnum``, which only exists from Python 3.11.

    Members are ``str`` instances, so they compare equal to their values.
    ``str(member)`` returns the raw value instead of ``Class.MEMBER``.
    """

    def __str__(self):
        value = self.value
        return value
|
7
16
|
|
8
17
|
|
9
18
|
class ColoredFormatter(logging.Formatter):
|
@@ -119,6 +128,9 @@ class Parameter(BaseModel):
|
|
119
128
|
}
|
120
129
|
}
|
121
130
|
|
131
|
+
class OpenAIAPI(StrEnum):
    """OpenAI API flavours supported by ``FunctionSchema.to_openai``."""

    COMPLETIONS = "completions"
    RESPONSES = "responses"
|
122
134
|
|
123
135
|
class FunctionSchema(BaseModel):
|
124
136
|
"""Class that consumes a function and can return a schema required by
|
@@ -167,29 +179,95 @@ class FunctionSchema(BaseModel):
|
|
167
179
|
)
|
168
180
|
|
169
181
|
@classmethod
def from_pydantic(cls, model: type[BaseModel]) -> "FunctionSchema":
    """Create a FunctionSchema from a Pydantic model class.

    Each model field becomes a :class:`Parameter`. ``Optional[T]``,
    ``Union[T, None]`` and ``T | None`` are reported as ``T`` (the first
    non-``None`` member). The generated signature returns ``dict``.

    :param model: The Pydantic model class to convert
    :type model: type[BaseModel]
    :return: FunctionSchema instance
    :rtype: FunctionSchema
    """
    # Local import: ``types.UnionType`` is the origin of PEP 604 ``X | None``.
    import types

    union_origins: tuple[Any, ...] = (Union, types.UnionType)

    # Extract model metadata
    name = model.__name__
    description = model.__doc__ or ""

    # Build parameters list
    parameters = []
    signature_parts = []

    for field_name, field_info in model.model_fields.items():
        # ``field_info.annotation`` is resolved by pydantic and also covers
        # inherited fields and postponed (string) annotations, unlike
        # ``model.__annotations__``.
        field_type = field_info.annotation

        # Determine the type name - handle Optional and other generic types
        type_name = str(field_type)
        if get_origin(field_type) in union_origins:
            # Likely Optional[T] / T | None: report the first non-None member.
            non_none_types = [arg for arg in get_args(field_type) if arg is not type(None)]
            if non_none_types:
                actual_type = non_none_types[0]
                if hasattr(actual_type, "__name__"):
                    type_name = actual_type.__name__
                else:
                    type_name = str(actual_type)
        elif field_type and hasattr(field_type, "__name__"):
            type_name = field_type.__name__

        # In Pydantic v2, PydanticUndefined means "no default"
        is_required = (
            field_info.default is PydanticUndefined
            and field_info.default_factory is None
        )

        # Get the actual default value
        if field_info.default is not PydanticUndefined and field_info.default is not None:
            default_value = field_info.default
        elif field_info.default_factory is not None:
            try:
                # Zero-argument factories are the common case
                default_value = field_info.default_factory()  # type: ignore[call-arg]
            except TypeError:
                # Factory needs arguments; just indicate it has a factory default
                default_value = "<factory>"
        else:
            default_value = inspect.Parameter.empty

        parameters.append(
            Parameter(
                name=field_name,
                description=field_info.description,
                type=type_name,
                default=default_value,
                required=is_required,
            )
        )

        # Build signature part
        if default_value != inspect.Parameter.empty:
            signature_parts.append(f"{field_name}: {type_name} = {repr(default_value)}")
        else:
            signature_parts.append(f"{field_name}: {type_name}")

    signature = f"({', '.join(signature_parts)}) -> dict"

    return cls.model_construct(
        name=name,
        description=description,
        signature=signature,
        output="dict",
        parameters=parameters,
    )
|
191
269
|
|
192
|
-
def to_dict(self) -> dict:
|
270
|
+
def to_dict(self) -> dict[str, Any]:
|
193
271
|
schema_dict = {
|
194
272
|
"type": "function",
|
195
273
|
"function": {
|
@@ -210,14 +288,42 @@ class FunctionSchema(BaseModel):
|
|
210
288
|
}
|
211
289
|
return schema_dict
|
212
290
|
|
213
|
-
def to_openai(self) -> dict:
|
214
|
-
|
291
|
+
def to_openai(self, api: OpenAIAPI=OpenAIAPI.COMPLETIONS) -> dict[str, Any]:
|
292
|
+
"""Convert the function schema into OpenAI-compatible formats. Supports
|
293
|
+
both completions and responses APIs.
|
294
|
+
|
295
|
+
:param api: The API to convert to.
|
296
|
+
:type api: OpenAIAPI
|
297
|
+
:return: The function schema in OpenAI-compatible format.
|
298
|
+
:rtype: dict
|
299
|
+
"""
|
300
|
+
if api == "completions":
|
301
|
+
return self.to_dict()
|
302
|
+
elif api == "responses":
|
303
|
+
return {
|
304
|
+
"type": "function",
|
305
|
+
"name": self.name,
|
306
|
+
"description": self.description,
|
307
|
+
"parameters": {
|
308
|
+
"type": "object",
|
309
|
+
"properties": {
|
310
|
+
k: v
|
311
|
+
for param in self.parameters
|
312
|
+
for k, v in param.to_dict().items()
|
313
|
+
},
|
314
|
+
"required": [
|
315
|
+
param.name for param in self.parameters if param.required
|
316
|
+
],
|
317
|
+
},
|
318
|
+
}
|
319
|
+
else:
|
320
|
+
raise ValueError(f"Unrecognized OpenAI API: {api}")
|
215
321
|
|
216
322
|
|
217
323
|
# ``format`` values for which ``get_schemas`` emits ``FunctionSchema.to_dict()``.
DEFAULT = {"default", "openai", "ollama", "litellm"}
|
218
324
|
|
219
325
|
|
220
|
-
def get_schemas(callables: list[Callable], format: str = "default") -> list[dict]:
|
326
|
+
def get_schemas(callables: list[Callable], format: str = "default") -> list[dict[str, Any]]:
|
221
327
|
if format in DEFAULT:
|
222
328
|
return [
|
223
329
|
FunctionSchema.from_callable(callable).to_dict() for callable in callables
|
@@ -1,6 +1,6 @@
|
|
1
1
|
[project]
|
2
2
|
name = "graphai-lib"
|
3
|
-
version = "0.0.
|
3
|
+
version = "0.0.9"
|
4
4
|
description = "Not an AI framework"
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.10,<3.14"
|
@@ -29,6 +29,9 @@ build-backend = "setuptools.build_meta"
|
|
29
29
|
[tool.setuptools]
|
30
30
|
packages = ["graphai", "graphai.nodes"]
|
31
31
|
|
32
|
+
[tool.setuptools.package-data]
|
33
|
+
graphai = ["py.typed"]
|
34
|
+
|
32
35
|
[tool.mypy]
|
33
36
|
python_version = "3.10"
|
34
37
|
warn_return_any = true
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|