graphai-lib 0.0.4__tar.gz → 0.0.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/PKG-INFO +4 -3
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai/__init__.py +1 -1
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai/callback.py +29 -44
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai/graph.py +95 -41
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai/nodes/__init__.py +1 -1
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai/nodes/base.py +53 -16
- graphai_lib-0.0.6/graphai/utils.py +205 -0
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai_lib.egg-info/PKG-INFO +4 -3
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai_lib.egg-info/requires.txt +3 -2
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/pyproject.toml +5 -4
- graphai_lib-0.0.4/graphai/utils.py +0 -125
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/README.md +0 -0
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai_lib.egg-info/SOURCES.txt +0 -0
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai_lib.egg-info/dependency_links.txt +0 -0
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/graphai_lib.egg-info/top_level.txt +0 -0
- {graphai_lib-0.0.4 → graphai_lib-0.0.6}/setup.cfg +0 -0
@@ -1,12 +1,13 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: graphai-lib
|
3
|
-
Version: 0.0.4
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: Not an AI framework
|
5
5
|
Requires-Python: <3.14,>=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
7
|
-
Requires-Dist: semantic-router>=0.1.5
|
8
7
|
Requires-Dist: networkx>=3.4.2
|
9
8
|
Requires-Dist: matplotlib>=3.10.0
|
9
|
+
Requires-Dist: pydantic>=2.11.1
|
10
|
+
Requires-Dist: colorlog>=6.9.0
|
10
11
|
Provides-Extra: dev
|
11
12
|
Requires-Dist: ipykernel>=6.25.0; extra == "dev"
|
12
13
|
Requires-Dist: ruff>=0.1.5; extra == "dev"
|
@@ -16,7 +17,7 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
16
17
|
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
|
17
18
|
Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
|
18
19
|
Requires-Dist: mypy>=1.7.1; extra == "dev"
|
19
|
-
Requires-Dist:
|
20
|
+
Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
|
20
21
|
Provides-Extra: docs
|
21
22
|
Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
|
22
23
|
|
@@ -1,12 +1,12 @@
|
|
1
1
|
import asyncio
|
2
2
|
from pydantic import Field
|
3
|
-
from typing import Optional
|
3
|
+
from typing import Optional, Any
|
4
4
|
from collections.abc import AsyncIterator
|
5
|
-
from semantic_router.utils.logger import logger
|
6
5
|
|
7
6
|
|
8
7
|
log_stream = True
|
9
8
|
|
9
|
+
|
10
10
|
class Callback:
|
11
11
|
identifier: str = Field(
|
12
12
|
default="graphai",
|
@@ -14,7 +14,7 @@ class Callback:
|
|
14
14
|
"The identifier for special tokens. This allows us to easily "
|
15
15
|
"identify special tokens in the stream so we can handle them "
|
16
16
|
"correctly in any downstream process."
|
17
|
-
)
|
17
|
+
),
|
18
18
|
)
|
19
19
|
special_token_format: str = Field(
|
20
20
|
default="<{identifier}:{token}:{params}>",
|
@@ -32,8 +32,8 @@ class Callback:
|
|
32
32
|
examples=[
|
33
33
|
"<{identifier}:{token}:{params}>",
|
34
34
|
"<[{identifier} | {token} | {params}]>",
|
35
|
-
"<{token}:{params}>"
|
36
|
-
]
|
35
|
+
"<{token}:{params}>",
|
36
|
+
],
|
37
37
|
)
|
38
38
|
token_format: str = Field(
|
39
39
|
default="{token}",
|
@@ -41,27 +41,23 @@ class Callback:
|
|
41
41
|
"The format for streamed tokens. This is used to format the "
|
42
42
|
"tokens typically returned from LLMs. By default, no special "
|
43
43
|
"formatting is applied."
|
44
|
-
)
|
44
|
+
),
|
45
45
|
)
|
46
46
|
_first_token: bool = Field(
|
47
47
|
default=True,
|
48
48
|
description="Whether this is the first token in the stream.",
|
49
|
-
exclude=True
|
49
|
+
exclude=True,
|
50
50
|
)
|
51
51
|
_current_node_name: Optional[str] = Field(
|
52
|
-
default=None,
|
53
|
-
description="The name of the current node.",
|
54
|
-
exclude=True
|
52
|
+
default=None, description="The name of the current node.", exclude=True
|
55
53
|
)
|
56
54
|
_active: bool = Field(
|
57
|
-
default=True,
|
58
|
-
description="Whether the callback is active.",
|
59
|
-
exclude=True
|
55
|
+
default=True, description="Whether the callback is active.", exclude=True
|
60
56
|
)
|
61
57
|
_done: bool = Field(
|
62
58
|
default=False,
|
63
59
|
description="Whether the stream is done and should be closed.",
|
64
|
-
exclude=True
|
60
|
+
exclude=True,
|
65
61
|
)
|
66
62
|
queue: asyncio.Queue
|
67
63
|
|
@@ -83,7 +79,7 @@ class Callback:
|
|
83
79
|
@property
|
84
80
|
def first_token(self) -> bool:
|
85
81
|
return self._first_token
|
86
|
-
|
82
|
+
|
87
83
|
@first_token.setter
|
88
84
|
def first_token(self, value: bool):
|
89
85
|
self._first_token = value
|
@@ -91,7 +87,7 @@ class Callback:
|
|
91
87
|
@property
|
92
88
|
def current_node_name(self) -> Optional[str]:
|
93
89
|
return self._current_node_name
|
94
|
-
|
90
|
+
|
95
91
|
@current_node_name.setter
|
96
92
|
def current_node_name(self, value: Optional[str]):
|
97
93
|
self._current_node_name = value
|
@@ -99,7 +95,7 @@ class Callback:
|
|
99
95
|
@property
|
100
96
|
def active(self) -> bool:
|
101
97
|
return self._active
|
102
|
-
|
98
|
+
|
103
99
|
@active.setter
|
104
100
|
def active(self, value: bool):
|
105
101
|
self._active = value
|
@@ -110,7 +106,7 @@ class Callback:
|
|
110
106
|
self._check_node_name(node_name=node_name)
|
111
107
|
# otherwise we just assume node is correct and send token
|
112
108
|
self.queue.put_nowait(token)
|
113
|
-
|
109
|
+
|
114
110
|
async def acall(self, token: str, node_name: Optional[str] = None):
|
115
111
|
# TODO JB: do we need to have `node_name` param?
|
116
112
|
if self._done:
|
@@ -118,16 +114,13 @@ class Callback:
|
|
118
114
|
self._check_node_name(node_name=node_name)
|
119
115
|
# otherwise we just assume node is correct and send token
|
120
116
|
self.queue.put_nowait(token)
|
121
|
-
|
117
|
+
|
122
118
|
async def aiter(self) -> AsyncIterator[str]:
|
123
119
|
"""Used by receiver to get the tokens from the stream queue. Creates
|
124
120
|
a generator that yields tokens from the queue until the END token is
|
125
121
|
received.
|
126
122
|
"""
|
127
|
-
end_token = await self._build_special_token(
|
128
|
-
name="END",
|
129
|
-
params=None
|
130
|
-
)
|
123
|
+
end_token = await self._build_special_token(name="END", params=None)
|
131
124
|
while True: # Keep going until we see the END token
|
132
125
|
try:
|
133
126
|
if self._done and self.queue.empty():
|
@@ -142,8 +135,7 @@ class Callback:
|
|
142
135
|
self._done = True # Mark as done after processing all tokens
|
143
136
|
|
144
137
|
async def start_node(self, node_name: str, active: bool = True):
|
145
|
-
"""Starts a new node and emits the start token.
|
146
|
-
"""
|
138
|
+
"""Starts a new node and emits the start token."""
|
147
139
|
if self._done:
|
148
140
|
raise RuntimeError("Cannot start node on a closed stream")
|
149
141
|
self.current_node_name = node_name
|
@@ -152,27 +144,23 @@ class Callback:
|
|
152
144
|
self.active = active
|
153
145
|
if self.active:
|
154
146
|
token = await self._build_special_token(
|
155
|
-
name=f"{self.current_node_name}:start",
|
156
|
-
params=None
|
147
|
+
name=f"{self.current_node_name}:start", params=None
|
157
148
|
)
|
158
149
|
self.queue.put_nowait(token)
|
159
150
|
# TODO JB: should we use two tokens here?
|
160
151
|
node_token = await self._build_special_token(
|
161
|
-
name=self.current_node_name,
|
162
|
-
params=None
|
152
|
+
name=self.current_node_name, params=None
|
163
153
|
)
|
164
154
|
self.queue.put_nowait(node_token)
|
165
|
-
|
155
|
+
|
166
156
|
async def end_node(self, node_name: str):
|
167
|
-
"""Emits the end token for the current node.
|
168
|
-
"""
|
157
|
+
"""Emits the end token for the current node."""
|
169
158
|
if self._done:
|
170
159
|
raise RuntimeError("Cannot end node on a closed stream")
|
171
|
-
#self.current_node_name = node_name
|
160
|
+
# self.current_node_name = node_name
|
172
161
|
if self.active:
|
173
162
|
node_token = await self._build_special_token(
|
174
|
-
name=f"{self.current_node_name}:end",
|
175
|
-
params=None
|
163
|
+
name=f"{self.current_node_name}:end", params=None
|
176
164
|
)
|
177
165
|
self.queue.put_nowait(node_token)
|
178
166
|
|
@@ -182,10 +170,7 @@ class Callback:
|
|
182
170
|
"""
|
183
171
|
if self._done:
|
184
172
|
return
|
185
|
-
end_token = await self._build_special_token(
|
186
|
-
name="END",
|
187
|
-
params=None
|
188
|
-
)
|
173
|
+
end_token = await self._build_special_token(name="END", params=None)
|
189
174
|
self._done = True # Set done before putting the end token
|
190
175
|
self.queue.put_nowait(end_token)
|
191
176
|
# Don't wait for queue.join() as it can cause deadlock
|
@@ -198,8 +183,10 @@ class Callback:
|
|
198
183
|
raise ValueError(
|
199
184
|
f"Node name mismatch: {self.current_node_name} != {node_name}"
|
200
185
|
)
|
201
|
-
|
202
|
-
async def _build_special_token(
|
186
|
+
|
187
|
+
async def _build_special_token(
|
188
|
+
self, name: str, params: dict[str, Any] | None = None
|
189
|
+
):
|
203
190
|
if params:
|
204
191
|
params_str = ",".join([f"{k}={v}" for k, v in params.items()])
|
205
192
|
else:
|
@@ -209,7 +196,5 @@ class Callback:
|
|
209
196
|
else:
|
210
197
|
identifier = ""
|
211
198
|
return self.special_token_format.format(
|
212
|
-
identifier=identifier,
|
213
|
-
token=name,
|
214
|
-
params=params_str
|
199
|
+
identifier=identifier, token=name, params=params_str
|
215
200
|
)
|
@@ -1,17 +1,33 @@
|
|
1
|
-
from typing import List, Dict, Any, Optional
|
2
|
-
from graphai.nodes.base import _Node
|
1
|
+
from typing import List, Dict, Any, Optional, Protocol, Type
|
3
2
|
from graphai.callback import Callback
|
4
|
-
from
|
3
|
+
from graphai.utils import logger
|
4
|
+
|
5
|
+
|
6
|
+
class NodeProtocol(Protocol):
|
7
|
+
"""Protocol defining the interface of a decorated node."""
|
8
|
+
name: str
|
9
|
+
is_start: bool
|
10
|
+
is_end: bool
|
11
|
+
is_router: bool
|
12
|
+
stream: bool
|
13
|
+
|
14
|
+
async def invoke(
|
15
|
+
self,
|
16
|
+
input: Dict[str, Any],
|
17
|
+
callback: Optional[Callback] = None,
|
18
|
+
state: Optional[Dict[str, Any]] = None
|
19
|
+
) -> Dict[str, Any]: ...
|
5
20
|
|
6
21
|
|
7
22
|
class Graph:
|
8
|
-
def __init__(
|
9
|
-
self
|
23
|
+
def __init__(
|
24
|
+
self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
|
25
|
+
):
|
26
|
+
self.nodes: Dict[str, NodeProtocol] = {}
|
10
27
|
self.edges: List[Any] = []
|
11
|
-
self.start_node: Optional[
|
12
|
-
self.end_nodes: List[
|
13
|
-
self.Callback = Callback
|
14
|
-
self.callback = None
|
28
|
+
self.start_node: Optional[NodeProtocol] = None
|
29
|
+
self.end_nodes: List[NodeProtocol] = []
|
30
|
+
self.Callback: Type[Callback] = Callback
|
15
31
|
self.max_steps = max_steps
|
16
32
|
self.state = initial_state or {}
|
17
33
|
|
@@ -32,7 +48,7 @@ class Graph:
|
|
32
48
|
"""Reset the graph state to an empty dict."""
|
33
49
|
self.state = {}
|
34
50
|
|
35
|
-
def add_node(self, node):
|
51
|
+
def add_node(self, node: NodeProtocol):
|
36
52
|
if node.name in self.nodes:
|
37
53
|
raise Exception(f"Node with name '{node.name}' already exists.")
|
38
54
|
self.nodes[node.name] = node
|
@@ -47,9 +63,9 @@ class Graph:
|
|
47
63
|
if node.is_end:
|
48
64
|
self.end_nodes.append(node)
|
49
65
|
|
50
|
-
def add_edge(self, source:
|
66
|
+
def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str):
|
51
67
|
"""Adds an edge between two nodes that already exist in the graph.
|
52
|
-
|
68
|
+
|
53
69
|
Args:
|
54
70
|
source: The source node or its name.
|
55
71
|
destination: The destination node or its name.
|
@@ -60,7 +76,7 @@ class Graph:
|
|
60
76
|
source_node = self.nodes.get(source)
|
61
77
|
else:
|
62
78
|
# Check if it's a node-like object by looking for required attributes
|
63
|
-
if hasattr(source,
|
79
|
+
if hasattr(source, "name"):
|
64
80
|
source_node = self.nodes.get(source.name)
|
65
81
|
if source_node is None:
|
66
82
|
raise ValueError(
|
@@ -71,7 +87,7 @@ class Graph:
|
|
71
87
|
destination_node = self.nodes.get(destination)
|
72
88
|
else:
|
73
89
|
# Check if it's a node-like object by looking for required attributes
|
74
|
-
if hasattr(destination,
|
90
|
+
if hasattr(destination, "name"):
|
75
91
|
destination_node = self.nodes.get(destination.name)
|
76
92
|
if destination_node is None:
|
77
93
|
raise ValueError(
|
@@ -80,17 +96,19 @@ class Graph:
|
|
80
96
|
edge = Edge(source_node, destination_node)
|
81
97
|
self.edges.append(edge)
|
82
98
|
|
83
|
-
def add_router(
|
99
|
+
def add_router(
|
100
|
+
self, sources: list[NodeProtocol], router: NodeProtocol, destinations: List[NodeProtocol]
|
101
|
+
):
|
84
102
|
if not router.is_router:
|
85
103
|
raise TypeError("A router object must be passed to the router parameter.")
|
86
104
|
[self.add_edge(source, router) for source in sources]
|
87
105
|
for destination in destinations:
|
88
106
|
self.add_edge(router, destination)
|
89
107
|
|
90
|
-
def set_start_node(self, node:
|
108
|
+
def set_start_node(self, node: NodeProtocol):
|
91
109
|
self.start_node = node
|
92
110
|
|
93
|
-
def set_end_node(self, node:
|
111
|
+
def set_end_node(self, node: NodeProtocol):
|
94
112
|
self.end_node = node
|
95
113
|
|
96
114
|
def compile(self):
|
@@ -112,11 +130,15 @@ class Graph:
|
|
112
130
|
f"Instead, got {type(output)} from '{output}'."
|
113
131
|
)
|
114
132
|
|
115
|
-
async def execute(self, input):
|
133
|
+
async def execute(self, input, callback: Callback | None = None):
|
116
134
|
# TODO JB: may need to add init callback here to init the queue on every new execution
|
117
|
-
if
|
118
|
-
|
135
|
+
if callback is None:
|
136
|
+
callback = self.get_callback()
|
137
|
+
|
138
|
+
# Type assertion to tell the type checker that start_node is not None after compile()
|
139
|
+
assert self.start_node is not None, "Graph must be compiled before execution"
|
119
140
|
current_node = self.start_node
|
141
|
+
|
120
142
|
state = input
|
121
143
|
# Don't reset the graph state if it was initialized with initial_state
|
122
144
|
steps = 0
|
@@ -124,11 +146,13 @@ class Graph:
|
|
124
146
|
# we invoke the node here
|
125
147
|
if current_node.stream:
|
126
148
|
# add callback tokens and param here if we are streaming
|
127
|
-
await
|
149
|
+
await callback.start_node(node_name=current_node.name)
|
128
150
|
# Include graph's internal state in the node execution context
|
129
|
-
output = await current_node.invoke(
|
151
|
+
output = await current_node.invoke(
|
152
|
+
input=state, callback=callback, state=self.state
|
153
|
+
)
|
130
154
|
self._validate_output(output=output, node_name=current_node.name)
|
131
|
-
await
|
155
|
+
await callback.end_node(node_name=current_node.name)
|
132
156
|
else:
|
133
157
|
# Include graph's internal state in the node execution context
|
134
158
|
output = await current_node.invoke(input=state, state=self.state)
|
@@ -153,24 +177,38 @@ class Graph:
|
|
153
177
|
"by setting `max_steps` when initializing the Graph object."
|
154
178
|
)
|
155
179
|
# TODO JB: may need to add end callback here to close the queue for every execution
|
156
|
-
if
|
157
|
-
await
|
180
|
+
if callback and "callback" in state:
|
181
|
+
await callback.close()
|
158
182
|
del state["callback"]
|
159
183
|
return state
|
160
184
|
|
161
185
|
def get_callback(self):
|
162
|
-
|
163
|
-
|
186
|
+
"""Get a new instance of the callback class.
|
187
|
+
|
188
|
+
:return: A new instance of the callback class.
|
189
|
+
:rtype: Callback
|
190
|
+
"""
|
191
|
+
callback = self.Callback()
|
192
|
+
return callback
|
193
|
+
|
194
|
+
def set_callback(self, callback_class: Type[Callback]):
|
195
|
+
"""Set the callback class that is returned by the `get_callback` method and used
|
196
|
+
as the default callback when no callback is passed to the `execute` method.
|
197
|
+
|
198
|
+
:param callback_class: The callback class to use as the default callback.
|
199
|
+
:type callback_class: Type[Callback]
|
200
|
+
"""
|
201
|
+
self.Callback = callback_class
|
164
202
|
|
165
|
-
def _get_node_by_name(self, node_name: str) ->
|
203
|
+
def _get_node_by_name(self, node_name: str) -> NodeProtocol:
|
166
204
|
"""Get a node by its name.
|
167
|
-
|
205
|
+
|
168
206
|
Args:
|
169
207
|
node_name: The name of the node to find.
|
170
|
-
|
208
|
+
|
171
209
|
Returns:
|
172
210
|
The node with the given name.
|
173
|
-
|
211
|
+
|
174
212
|
Raises:
|
175
213
|
Exception: If no node with the given name is found.
|
176
214
|
"""
|
@@ -191,12 +229,16 @@ class Graph:
|
|
191
229
|
try:
|
192
230
|
import networkx as nx
|
193
231
|
except ImportError:
|
194
|
-
raise ImportError(
|
232
|
+
raise ImportError(
|
233
|
+
"NetworkX is required for visualization. Please install it with 'pip install networkx'."
|
234
|
+
)
|
195
235
|
|
196
236
|
try:
|
197
237
|
import matplotlib.pyplot as plt
|
198
238
|
except ImportError:
|
199
|
-
raise ImportError(
|
239
|
+
raise ImportError(
|
240
|
+
"Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
|
241
|
+
)
|
200
242
|
|
201
243
|
G = nx.DiGraph()
|
202
244
|
|
@@ -207,7 +249,9 @@ class Graph:
|
|
207
249
|
G.add_edge(edge.source.name, edge.destination.name)
|
208
250
|
|
209
251
|
if nx.is_directed_acyclic_graph(G):
|
210
|
-
logger.info(
|
252
|
+
logger.info(
|
253
|
+
"The graph is acyclic. Visualization will use a topological layout."
|
254
|
+
)
|
211
255
|
# Use topological layout if acyclic
|
212
256
|
# Compute the topological generations
|
213
257
|
generations = list(nx.topological_generations(G))
|
@@ -241,20 +285,30 @@ class Graph:
|
|
241
285
|
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
|
242
286
|
|
243
287
|
else:
|
244
|
-
print(
|
288
|
+
print(
|
289
|
+
"Warning: The graph contains cycles. Visualization will use a spring layout."
|
290
|
+
)
|
245
291
|
pos = nx.spring_layout(G, k=1, iterations=50)
|
246
292
|
|
247
293
|
plt.figure(figsize=(8, 6))
|
248
|
-
nx.draw(
|
249
|
-
|
250
|
-
|
294
|
+
nx.draw(
|
295
|
+
G,
|
296
|
+
pos,
|
297
|
+
with_labels=True,
|
298
|
+
node_color="lightblue",
|
299
|
+
node_size=3000,
|
300
|
+
font_size=8,
|
301
|
+
font_weight="bold",
|
302
|
+
arrows=True,
|
303
|
+
edge_color="gray",
|
304
|
+
arrowsize=20,
|
305
|
+
)
|
251
306
|
|
252
|
-
plt.axis(
|
307
|
+
plt.axis("off")
|
253
308
|
plt.show()
|
254
309
|
|
255
310
|
|
256
|
-
|
257
311
|
class Edge:
|
258
312
|
def __init__(self, source, destination):
|
259
313
|
self.source = source
|
260
|
-
self.destination = destination
|
314
|
+
self.destination = destination
|
@@ -1,5 +1,6 @@
|
|
1
1
|
import inspect
|
2
2
|
from typing import Any, Callable, Dict, Optional
|
3
|
+
from pydantic import Field
|
3
4
|
|
4
5
|
from graphai.callback import Callback
|
5
6
|
from graphai.utils import FunctionSchema
|
@@ -9,7 +10,11 @@ class NodeMeta(type):
|
|
9
10
|
@staticmethod
|
10
11
|
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
|
11
12
|
init_signature = inspect.signature(cls_type.__init__)
|
12
|
-
init_params = {
|
13
|
+
init_params = {
|
14
|
+
name: arg
|
15
|
+
for name, arg in init_signature.parameters.items()
|
16
|
+
if name != "self"
|
17
|
+
}
|
13
18
|
return init_params
|
14
19
|
|
15
20
|
def __call__(cls, *args, **kwargs):
|
@@ -33,18 +38,32 @@ class _Node:
|
|
33
38
|
stream: bool = False,
|
34
39
|
name: str | None = None,
|
35
40
|
) -> Callable:
|
36
|
-
"""Decorator validating node structure.
|
37
|
-
"""
|
41
|
+
"""Decorator validating node structure."""
|
38
42
|
if not callable(func):
|
39
43
|
raise ValueError("Node must be a callable function.")
|
40
|
-
|
44
|
+
|
41
45
|
func_signature = inspect.signature(func)
|
42
|
-
schema = FunctionSchema(func)
|
46
|
+
schema: FunctionSchema = FunctionSchema.from_callable(func)
|
43
47
|
|
44
48
|
class NodeClass:
|
45
49
|
_func_signature = func_signature
|
46
|
-
is_router =
|
47
|
-
|
50
|
+
is_router: bool = Field(
|
51
|
+
default=False, description="Whether the node is a router."
|
52
|
+
)
|
53
|
+
# following attributes will be overridden by the decorator
|
54
|
+
name: str | None = Field(default=None, description="The name of the node.")
|
55
|
+
is_start: bool = Field(
|
56
|
+
default=False, description="Whether the node is the start of the graph."
|
57
|
+
)
|
58
|
+
is_end: bool = Field(
|
59
|
+
default=False, description="Whether the node is the end of the graph."
|
60
|
+
)
|
61
|
+
schema: FunctionSchema | None = Field(
|
62
|
+
default=None, description="The schema of the node."
|
63
|
+
)
|
64
|
+
stream: bool = Field(
|
65
|
+
default=False, description="Whether the node includes streaming object."
|
66
|
+
)
|
48
67
|
|
49
68
|
def __init__(self):
|
50
69
|
self._expected_params = set(self._func_signature.parameters.keys())
|
@@ -56,9 +75,13 @@ class _Node:
|
|
56
75
|
|
57
76
|
async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
|
58
77
|
# filter out unexpected keyword args
|
59
|
-
expected_kwargs = {
|
78
|
+
expected_kwargs = {
|
79
|
+
k: v for k, v in kwargs.items() if k in self._expected_params
|
80
|
+
}
|
60
81
|
# Convert args to kwargs based on the function signature
|
61
|
-
args_names = list(self._func_signature.parameters.keys())[
|
82
|
+
args_names = list(self._func_signature.parameters.keys())[
|
83
|
+
1 : len(args) + 1
|
84
|
+
] # skip 'self'
|
62
85
|
expected_args_kwargs = dict(zip(args_names, args))
|
63
86
|
# Combine filtered args and kwargs
|
64
87
|
combined_params = {**expected_args_kwargs, **expected_kwargs}
|
@@ -87,7 +110,6 @@ class _Node:
|
|
87
110
|
)
|
88
111
|
return filtered_params
|
89
112
|
|
90
|
-
|
91
113
|
@classmethod
|
92
114
|
def get_signature(cls):
|
93
115
|
"""Returns the signature of the decorated function as LLM readable
|
@@ -97,15 +119,24 @@ class _Node:
|
|
97
119
|
if NodeClass._func_signature:
|
98
120
|
for param in NodeClass._func_signature.parameters.values():
|
99
121
|
if param.default is param.empty:
|
100
|
-
signature_components.append(
|
122
|
+
signature_components.append(
|
123
|
+
f"{param.name}: {param.annotation}"
|
124
|
+
)
|
101
125
|
else:
|
102
|
-
signature_components.append(
|
126
|
+
signature_components.append(
|
127
|
+
f"{param.name}: {param.annotation} = {param.default}"
|
128
|
+
)
|
103
129
|
else:
|
104
130
|
return "No signature"
|
105
131
|
return "\n".join(signature_components)
|
106
132
|
|
107
133
|
@classmethod
|
108
|
-
async def invoke(
|
134
|
+
async def invoke(
|
135
|
+
cls,
|
136
|
+
input: Dict[str, Any],
|
137
|
+
callback: Optional[Callback] = None,
|
138
|
+
state: Optional[Dict[str, Any]] = None,
|
139
|
+
):
|
109
140
|
if callback:
|
110
141
|
if stream:
|
111
142
|
input["callback"] = callback
|
@@ -116,13 +147,16 @@ class _Node:
|
|
116
147
|
# Add state to the input if present and the parameter exists in the function signature
|
117
148
|
if state is not None and "state" in cls._func_signature.parameters:
|
118
149
|
input["state"] = state
|
119
|
-
|
150
|
+
|
120
151
|
instance = cls()
|
121
152
|
out = await instance.execute(**input)
|
122
153
|
return out
|
123
154
|
|
124
155
|
NodeClass.__name__ = func.__name__
|
125
|
-
|
156
|
+
node_class_name = name or func.__name__
|
157
|
+
if node_class_name is None:
|
158
|
+
raise ValueError("Unexpected error: node name not set.")
|
159
|
+
NodeClass.name = node_class_name
|
126
160
|
NodeClass.__doc__ = func.__doc__
|
127
161
|
NodeClass.is_start = start
|
128
162
|
NodeClass.is_end = end
|
@@ -141,8 +175,11 @@ class _Node:
|
|
141
175
|
):
|
142
176
|
# We must wrap the call to the decorator in a function for it to work
|
143
177
|
# correctly with or without parenthesis
|
144
|
-
def wrap(
|
178
|
+
def wrap(
|
179
|
+
func: Callable, start=start, end=end, stream=stream, name=name
|
180
|
+
) -> Callable:
|
145
181
|
return self._node(func=func, start=start, end=end, stream=stream, name=name)
|
182
|
+
|
146
183
|
if func:
|
147
184
|
# Decorator is called without parenthesis
|
148
185
|
return wrap(func=func, start=start, end=end, stream=stream, name=name)
|
@@ -0,0 +1,205 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import Any, Callable, List, Optional
|
3
|
+
from pydantic import BaseModel, Field
|
4
|
+
import logging
|
5
|
+
|
6
|
+
import colorlog
|
7
|
+
|
8
|
+
|
9
|
+
class CustomFormatter(colorlog.ColoredFormatter):
|
10
|
+
"""Custom formatter for the logger."""
|
11
|
+
|
12
|
+
def __init__(self):
|
13
|
+
super().__init__(
|
14
|
+
"%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
|
15
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
16
|
+
log_colors={
|
17
|
+
"DEBUG": "cyan",
|
18
|
+
"INFO": "green",
|
19
|
+
"WARNING": "yellow",
|
20
|
+
"ERROR": "red",
|
21
|
+
"CRITICAL": "bold_red",
|
22
|
+
},
|
23
|
+
reset=True,
|
24
|
+
style="%",
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
def add_coloured_handler(logger):
|
29
|
+
"""Add a coloured handler to the logger."""
|
30
|
+
formatter = CustomFormatter()
|
31
|
+
console_handler = logging.StreamHandler()
|
32
|
+
console_handler.setFormatter(formatter)
|
33
|
+
logger.addHandler(console_handler)
|
34
|
+
return logger
|
35
|
+
|
36
|
+
|
37
|
+
def setup_custom_logger(name):
|
38
|
+
"""Setup a custom logger."""
|
39
|
+
logger = logging.getLogger(name)
|
40
|
+
|
41
|
+
if not logger.hasHandlers():
|
42
|
+
add_coloured_handler(logger)
|
43
|
+
logger.setLevel(logging.INFO)
|
44
|
+
logger.propagate = False
|
45
|
+
|
46
|
+
return logger
|
47
|
+
|
48
|
+
|
49
|
+
logger: logging.Logger = setup_custom_logger(__name__)
|
50
|
+
|
51
|
+
|
52
|
+
def openai_type_mapping(param_type: str) -> str:
|
53
|
+
if param_type == "int":
|
54
|
+
return "number"
|
55
|
+
elif param_type == "float":
|
56
|
+
return "number"
|
57
|
+
elif param_type == "str":
|
58
|
+
return "string"
|
59
|
+
elif param_type == "bool":
|
60
|
+
return "boolean"
|
61
|
+
else:
|
62
|
+
return "object"
|
63
|
+
|
64
|
+
|
65
|
+
class Parameter(BaseModel):
|
66
|
+
"""Parameter for a function.
|
67
|
+
|
68
|
+
:param name: The name of the parameter.
|
69
|
+
:type name: str
|
70
|
+
:param description: The description of the parameter.
|
71
|
+
:type description: Optional[str]
|
72
|
+
:param type: The type of the parameter.
|
73
|
+
:type type: str
|
74
|
+
:param default: The default value of the parameter.
|
75
|
+
:type default: Any
|
76
|
+
:param required: Whether the parameter is required.
|
77
|
+
:type required: bool
|
78
|
+
"""
|
79
|
+
|
80
|
+
name: str = Field(description="The name of the parameter")
|
81
|
+
description: Optional[str] = Field(
|
82
|
+
default=None, description="The description of the parameter"
|
83
|
+
)
|
84
|
+
type: str = Field(description="The type of the parameter")
|
85
|
+
default: Any = Field(description="The default value of the parameter")
|
86
|
+
required: bool = Field(description="Whether the parameter is required")
|
87
|
+
|
88
|
+
def to_dict(self) -> dict[str, Any]:
|
89
|
+
"""Convert the parameter to a dictionary for an standard dictionary-based function schema.
|
90
|
+
This is the most common format used by LLM providers, including OpenAI, Ollama, and others.
|
91
|
+
|
92
|
+
:return: The parameter in dictionary format.
|
93
|
+
:rtype: dict[str, Any]
|
94
|
+
"""
|
95
|
+
return {
|
96
|
+
self.name: {
|
97
|
+
"description": self.description,
|
98
|
+
"type": openai_type_mapping(self.type),
|
99
|
+
}
|
100
|
+
}
|
101
|
+
|
102
|
+
|
103
|
+
class FunctionSchema(BaseModel):
|
104
|
+
"""Class that consumes a function and can return a schema required by
|
105
|
+
different LLMs for function calling.
|
106
|
+
"""
|
107
|
+
|
108
|
+
name: str = Field(description="The name of the function")
|
109
|
+
description: str = Field(description="The description of the function")
|
110
|
+
signature: str = Field(description="The signature of the function")
|
111
|
+
output: str = Field(description="The output of the function")
|
112
|
+
parameters: list[Parameter] = Field(description="The parameters of the function")
|
113
|
+
|
114
|
+
@classmethod
|
115
|
+
def from_callable(cls, function: Callable) -> "FunctionSchema":
|
116
|
+
"""Initialize the FunctionSchema.
|
117
|
+
|
118
|
+
:param function: The function to consume.
|
119
|
+
:type function: Callable
|
120
|
+
"""
|
121
|
+
if callable(function):
|
122
|
+
name = function.__name__
|
123
|
+
description = str(inspect.getdoc(function))
|
124
|
+
if description is None or description == "":
|
125
|
+
logger.warning(f"Function {name} has no docstring")
|
126
|
+
signature = str(inspect.signature(function))
|
127
|
+
output = str(inspect.signature(function).return_annotation)
|
128
|
+
parameters = []
|
129
|
+
for param in inspect.signature(function).parameters.values():
|
130
|
+
parameters.append(
|
131
|
+
Parameter(
|
132
|
+
name=param.name,
|
133
|
+
type=param.annotation.__name__,
|
134
|
+
default=param.default,
|
135
|
+
required=param.default is inspect.Parameter.empty,
|
136
|
+
)
|
137
|
+
)
|
138
|
+
return cls.model_construct(
|
139
|
+
name=name,
|
140
|
+
description=description,
|
141
|
+
signature=signature,
|
142
|
+
output=output,
|
143
|
+
parameters=parameters,
|
144
|
+
)
|
145
|
+
elif isinstance(function, BaseModel):
|
146
|
+
raise NotImplementedError("Pydantic BaseModel not implemented yet.")
|
147
|
+
else:
|
148
|
+
raise TypeError("Function must be a Callable or BaseModel")
|
149
|
+
|
150
|
+
@classmethod
|
151
|
+
def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
|
152
|
+
signature_parts = []
|
153
|
+
for field_name, field_model in model.__annotations__.items():
|
154
|
+
field_info = model.model_fields[field_name]
|
155
|
+
default_value = field_info.default
|
156
|
+
if default_value:
|
157
|
+
default_repr = repr(default_value)
|
158
|
+
signature_part = (
|
159
|
+
f"{field_name}: {field_model.__name__} = {default_repr}"
|
160
|
+
)
|
161
|
+
else:
|
162
|
+
signature_part = f"{field_name}: {field_model.__name__}"
|
163
|
+
signature_parts.append(signature_part)
|
164
|
+
signature = f"({', '.join(signature_parts)}) -> str"
|
165
|
+
return cls.model_construct(
|
166
|
+
name=model.__class__.__name__,
|
167
|
+
description=model.__doc__ or "",
|
168
|
+
signature=signature,
|
169
|
+
output="", # TODO: Implement output
|
170
|
+
parameters=[],
|
171
|
+
)
|
172
|
+
|
173
|
+
def to_dict(self) -> dict:
|
174
|
+
schema_dict = {
|
175
|
+
"type": "function",
|
176
|
+
"function": {
|
177
|
+
"name": self.name,
|
178
|
+
"description": self.description,
|
179
|
+
"parameters": {
|
180
|
+
"type": "object",
|
181
|
+
"properties": {
|
182
|
+
k: v for param in self.parameters for k, v in param.to_dict().items()
|
183
|
+
},
|
184
|
+
"required": [
|
185
|
+
param.name for param in self.parameters if param.required
|
186
|
+
],
|
187
|
+
},
|
188
|
+
},
|
189
|
+
}
|
190
|
+
return schema_dict
|
191
|
+
|
192
|
+
def to_openai(self) -> dict:
|
193
|
+
return self.to_dict()
|
194
|
+
|
195
|
+
|
196
|
+
DEFAULT = set(["default", "openai", "ollama", "litellm"])
|
197
|
+
|
198
|
+
|
199
|
+
def get_schemas(callables: List[Callable], format: str = "default") -> list[dict]:
|
200
|
+
if format in DEFAULT:
|
201
|
+
return [
|
202
|
+
FunctionSchema.from_callable(callable).to_dict() for callable in callables
|
203
|
+
]
|
204
|
+
else:
|
205
|
+
raise ValueError(f"Format {format} not supported")
|
@@ -1,12 +1,13 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: graphai-lib
|
3
|
-
Version: 0.0.4
|
3
|
+
Version: 0.0.6
|
4
4
|
Summary: Not an AI framework
|
5
5
|
Requires-Python: <3.14,>=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
7
|
-
Requires-Dist: semantic-router>=0.1.5
|
8
7
|
Requires-Dist: networkx>=3.4.2
|
9
8
|
Requires-Dist: matplotlib>=3.10.0
|
9
|
+
Requires-Dist: pydantic>=2.11.1
|
10
|
+
Requires-Dist: colorlog>=6.9.0
|
10
11
|
Provides-Extra: dev
|
11
12
|
Requires-Dist: ipykernel>=6.25.0; extra == "dev"
|
12
13
|
Requires-Dist: ruff>=0.1.5; extra == "dev"
|
@@ -16,7 +17,7 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
16
17
|
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
|
17
18
|
Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
|
18
19
|
Requires-Dist: mypy>=1.7.1; extra == "dev"
|
19
|
-
Requires-Dist:
|
20
|
+
Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
|
20
21
|
Provides-Extra: docs
|
21
22
|
Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
|
22
23
|
|
@@ -1,6 +1,7 @@
|
|
1
|
-
semantic-router>=0.1.5
|
2
1
|
networkx>=3.4.2
|
3
2
|
matplotlib>=3.10.0
|
3
|
+
pydantic>=2.11.1
|
4
|
+
colorlog>=6.9.0
|
4
5
|
|
5
6
|
[dev]
|
6
7
|
ipykernel>=6.25.0
|
@@ -11,7 +12,7 @@ pytest-cov>=4.1.0
|
|
11
12
|
pytest-xdist>=3.5.0
|
12
13
|
pytest-asyncio>=0.24.0
|
13
14
|
mypy>=1.7.1
|
14
|
-
|
15
|
+
types-networkx>=3.4.2.20250319
|
15
16
|
|
16
17
|
[docs]
|
17
18
|
|
@@ -1,13 +1,14 @@
|
|
1
1
|
[project]
|
2
2
|
name = "graphai-lib"
|
3
|
-
version = "0.0.
|
3
|
+
version = "0.0.6"
|
4
4
|
description = "Not an AI framework"
|
5
5
|
readme = "README.md"
|
6
6
|
requires-python = ">=3.10,<3.14"
|
7
7
|
dependencies = [
|
8
|
-
"semantic-router>=0.1.5",
|
9
8
|
"networkx>=3.4.2",
|
10
9
|
"matplotlib>=3.10.0",
|
10
|
+
"pydantic>=2.11.1",
|
11
|
+
"colorlog>=6.9.0",
|
11
12
|
]
|
12
13
|
|
13
14
|
[project.optional-dependencies]
|
@@ -20,7 +21,7 @@ dev = [
|
|
20
21
|
"pytest-xdist>=3.5.0",
|
21
22
|
"pytest-asyncio>=0.24.0",
|
22
23
|
"mypy>=1.7.1",
|
23
|
-
"
|
24
|
+
"types-networkx>=3.4.2.20250319",
|
24
25
|
]
|
25
26
|
docs = ["pydoc-markdown>=4.8.2 ; python_version < '3.12'"]
|
26
27
|
|
@@ -29,4 +30,4 @@ requires = ["setuptools>=61.0"]
|
|
29
30
|
build-backend = "setuptools.build_meta"
|
30
31
|
|
31
32
|
[tool.setuptools]
|
32
|
-
packages = ["graphai", "graphai.nodes"]
|
33
|
+
packages = ["graphai", "graphai.nodes"]
|
@@ -1,125 +0,0 @@
|
|
1
|
-
import inspect
|
2
|
-
from typing import Any, Callable, Dict, List, Union, Optional
|
3
|
-
from pydantic import BaseModel, Field
|
4
|
-
|
5
|
-
|
6
|
-
class Parameter(BaseModel):
|
7
|
-
class Config:
|
8
|
-
arbitrary_types_allowed = True
|
9
|
-
|
10
|
-
name: str = Field(description="The name of the parameter")
|
11
|
-
description: Optional[str] = Field(
|
12
|
-
default=None, description="The description of the parameter"
|
13
|
-
)
|
14
|
-
type: str = Field(description="The type of the parameter")
|
15
|
-
default: Any = Field(description="The default value of the parameter")
|
16
|
-
required: bool = Field(description="Whether the parameter is required")
|
17
|
-
|
18
|
-
def to_openai(self):
|
19
|
-
return {
|
20
|
-
self.name: {
|
21
|
-
"description": self.description,
|
22
|
-
"type": self.type,
|
23
|
-
}
|
24
|
-
}
|
25
|
-
|
26
|
-
class FunctionSchema:
|
27
|
-
"""Class that consumes a function and can return a schema required by
|
28
|
-
different LLMs for function calling.
|
29
|
-
"""
|
30
|
-
|
31
|
-
name: str = Field(description="The name of the function")
|
32
|
-
description: str = Field(description="The description of the function")
|
33
|
-
signature: str = Field(description="The signature of the function")
|
34
|
-
output: str = Field(description="The output of the function")
|
35
|
-
parameters: List[Parameter] = Field(description="The parameters of the function")
|
36
|
-
|
37
|
-
def __init__(self, function: Union[Callable, BaseModel]):
|
38
|
-
self.function = function
|
39
|
-
if callable(function):
|
40
|
-
self._process_function(function)
|
41
|
-
elif isinstance(function, BaseModel):
|
42
|
-
raise NotImplementedError("Pydantic BaseModel not implemented yet.")
|
43
|
-
else:
|
44
|
-
raise TypeError("Function must be a Callable or BaseModel")
|
45
|
-
|
46
|
-
def _process_function(self, function: Callable):
|
47
|
-
self.name = function.__name__
|
48
|
-
self.description = str(inspect.getdoc(function))
|
49
|
-
self.signature = str(inspect.signature(function))
|
50
|
-
self.output = str(inspect.signature(function).return_annotation)
|
51
|
-
parameters = []
|
52
|
-
for param in inspect.signature(function).parameters.values():
|
53
|
-
parameters.append(
|
54
|
-
Parameter(
|
55
|
-
name=param.name,
|
56
|
-
type=param.annotation.__name__,
|
57
|
-
default=param.default,
|
58
|
-
required=param.default is inspect.Parameter.empty,
|
59
|
-
)
|
60
|
-
)
|
61
|
-
self.parameters = parameters
|
62
|
-
|
63
|
-
def to_openai(self):
|
64
|
-
schema_dict = {
|
65
|
-
"type": "function",
|
66
|
-
"function": {
|
67
|
-
"name": self.name,
|
68
|
-
"description": self.description,
|
69
|
-
"parameters": {
|
70
|
-
"type": "object",
|
71
|
-
"properties": {
|
72
|
-
param.name: {
|
73
|
-
"description": (
|
74
|
-
param.description
|
75
|
-
if isinstance(param.description, str)
|
76
|
-
else "None provided"
|
77
|
-
),
|
78
|
-
"type": self._openai_type_mapping(param.type),
|
79
|
-
}
|
80
|
-
for param in self.parameters
|
81
|
-
},
|
82
|
-
"required": [
|
83
|
-
param.name for param in self.parameters if param.required
|
84
|
-
],
|
85
|
-
},
|
86
|
-
},
|
87
|
-
}
|
88
|
-
return schema_dict
|
89
|
-
|
90
|
-
def _openai_type_mapping(self, param_type: str) -> str:
|
91
|
-
if param_type == "int":
|
92
|
-
return "number"
|
93
|
-
elif param_type == "float":
|
94
|
-
return "number"
|
95
|
-
elif param_type == "str":
|
96
|
-
return "string"
|
97
|
-
elif param_type == "bool":
|
98
|
-
return "boolean"
|
99
|
-
else:
|
100
|
-
return "object"
|
101
|
-
|
102
|
-
|
103
|
-
def get_schema_pydantic(model: BaseModel) -> Dict[str, Any]:
|
104
|
-
signature_parts = []
|
105
|
-
for field_name, field_model in model.__annotations__.items():
|
106
|
-
field_info = model.__fields__[field_name]
|
107
|
-
default_value = field_info.default
|
108
|
-
|
109
|
-
if default_value:
|
110
|
-
default_repr = repr(default_value)
|
111
|
-
signature_part = (
|
112
|
-
f"{field_name}: {field_model.__name__} = {default_repr}"
|
113
|
-
)
|
114
|
-
else:
|
115
|
-
signature_part = f"{field_name}: {field_model.__name__}"
|
116
|
-
|
117
|
-
signature_parts.append(signature_part)
|
118
|
-
signature = f"({', '.join(signature_parts)}) -> str"
|
119
|
-
schema = FunctionSchema(
|
120
|
-
name=model.__class__.__name__,
|
121
|
-
description=model.__doc__,
|
122
|
-
signature=signature,
|
123
|
-
output="", # TODO: Implement output
|
124
|
-
)
|
125
|
-
return schema
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|