graphai-lib 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphai/__init__.py +1 -1
- graphai/callback.py +29 -44
- graphai/graph.py +42 -20
- graphai/nodes/__init__.py +1 -1
- graphai/nodes/base.py +53 -16
- graphai/utils.py +148 -68
- {graphai_lib-0.0.4.dist-info → graphai_lib-0.0.5.dist-info}/METADATA +4 -3
- graphai_lib-0.0.5.dist-info/RECORD +10 -0
- {graphai_lib-0.0.4.dist-info → graphai_lib-0.0.5.dist-info}/WHEEL +1 -1
- graphai_lib-0.0.4.dist-info/RECORD +0 -10
- {graphai_lib-0.0.4.dist-info → graphai_lib-0.0.5.dist-info}/top_level.txt +0 -0
graphai/__init__.py
CHANGED
graphai/callback.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1
1
|
import asyncio
|
2
2
|
from pydantic import Field
|
3
|
-
from typing import Optional
|
3
|
+
from typing import Optional, Any
|
4
4
|
from collections.abc import AsyncIterator
|
5
|
-
from semantic_router.utils.logger import logger
|
6
5
|
|
7
6
|
|
8
7
|
log_stream = True
|
9
8
|
|
9
|
+
|
10
10
|
class Callback:
|
11
11
|
identifier: str = Field(
|
12
12
|
default="graphai",
|
@@ -14,7 +14,7 @@ class Callback:
|
|
14
14
|
"The identifier for special tokens. This allows us to easily "
|
15
15
|
"identify special tokens in the stream so we can handle them "
|
16
16
|
"correctly in any downstream process."
|
17
|
-
)
|
17
|
+
),
|
18
18
|
)
|
19
19
|
special_token_format: str = Field(
|
20
20
|
default="<{identifier}:{token}:{params}>",
|
@@ -32,8 +32,8 @@ class Callback:
|
|
32
32
|
examples=[
|
33
33
|
"<{identifier}:{token}:{params}>",
|
34
34
|
"<[{identifier} | {token} | {params}]>",
|
35
|
-
"<{token}:{params}>"
|
36
|
-
]
|
35
|
+
"<{token}:{params}>",
|
36
|
+
],
|
37
37
|
)
|
38
38
|
token_format: str = Field(
|
39
39
|
default="{token}",
|
@@ -41,27 +41,23 @@ class Callback:
|
|
41
41
|
"The format for streamed tokens. This is used to format the "
|
42
42
|
"tokens typically returned from LLMs. By default, no special "
|
43
43
|
"formatting is applied."
|
44
|
-
)
|
44
|
+
),
|
45
45
|
)
|
46
46
|
_first_token: bool = Field(
|
47
47
|
default=True,
|
48
48
|
description="Whether this is the first token in the stream.",
|
49
|
-
exclude=True
|
49
|
+
exclude=True,
|
50
50
|
)
|
51
51
|
_current_node_name: Optional[str] = Field(
|
52
|
-
default=None,
|
53
|
-
description="The name of the current node.",
|
54
|
-
exclude=True
|
52
|
+
default=None, description="The name of the current node.", exclude=True
|
55
53
|
)
|
56
54
|
_active: bool = Field(
|
57
|
-
default=True,
|
58
|
-
description="Whether the callback is active.",
|
59
|
-
exclude=True
|
55
|
+
default=True, description="Whether the callback is active.", exclude=True
|
60
56
|
)
|
61
57
|
_done: bool = Field(
|
62
58
|
default=False,
|
63
59
|
description="Whether the stream is done and should be closed.",
|
64
|
-
exclude=True
|
60
|
+
exclude=True,
|
65
61
|
)
|
66
62
|
queue: asyncio.Queue
|
67
63
|
|
@@ -83,7 +79,7 @@ class Callback:
|
|
83
79
|
@property
|
84
80
|
def first_token(self) -> bool:
|
85
81
|
return self._first_token
|
86
|
-
|
82
|
+
|
87
83
|
@first_token.setter
|
88
84
|
def first_token(self, value: bool):
|
89
85
|
self._first_token = value
|
@@ -91,7 +87,7 @@ class Callback:
|
|
91
87
|
@property
|
92
88
|
def current_node_name(self) -> Optional[str]:
|
93
89
|
return self._current_node_name
|
94
|
-
|
90
|
+
|
95
91
|
@current_node_name.setter
|
96
92
|
def current_node_name(self, value: Optional[str]):
|
97
93
|
self._current_node_name = value
|
@@ -99,7 +95,7 @@ class Callback:
|
|
99
95
|
@property
|
100
96
|
def active(self) -> bool:
|
101
97
|
return self._active
|
102
|
-
|
98
|
+
|
103
99
|
@active.setter
|
104
100
|
def active(self, value: bool):
|
105
101
|
self._active = value
|
@@ -110,7 +106,7 @@ class Callback:
|
|
110
106
|
self._check_node_name(node_name=node_name)
|
111
107
|
# otherwise we just assume node is correct and send token
|
112
108
|
self.queue.put_nowait(token)
|
113
|
-
|
109
|
+
|
114
110
|
async def acall(self, token: str, node_name: Optional[str] = None):
|
115
111
|
# TODO JB: do we need to have `node_name` param?
|
116
112
|
if self._done:
|
@@ -118,16 +114,13 @@ class Callback:
|
|
118
114
|
self._check_node_name(node_name=node_name)
|
119
115
|
# otherwise we just assume node is correct and send token
|
120
116
|
self.queue.put_nowait(token)
|
121
|
-
|
117
|
+
|
122
118
|
async def aiter(self) -> AsyncIterator[str]:
|
123
119
|
"""Used by receiver to get the tokens from the stream queue. Creates
|
124
120
|
a generator that yields tokens from the queue until the END token is
|
125
121
|
received.
|
126
122
|
"""
|
127
|
-
end_token = await self._build_special_token(
|
128
|
-
name="END",
|
129
|
-
params=None
|
130
|
-
)
|
123
|
+
end_token = await self._build_special_token(name="END", params=None)
|
131
124
|
while True: # Keep going until we see the END token
|
132
125
|
try:
|
133
126
|
if self._done and self.queue.empty():
|
@@ -142,8 +135,7 @@ class Callback:
|
|
142
135
|
self._done = True # Mark as done after processing all tokens
|
143
136
|
|
144
137
|
async def start_node(self, node_name: str, active: bool = True):
|
145
|
-
"""Starts a new node and emits the start token.
|
146
|
-
"""
|
138
|
+
"""Starts a new node and emits the start token."""
|
147
139
|
if self._done:
|
148
140
|
raise RuntimeError("Cannot start node on a closed stream")
|
149
141
|
self.current_node_name = node_name
|
@@ -152,27 +144,23 @@ class Callback:
|
|
152
144
|
self.active = active
|
153
145
|
if self.active:
|
154
146
|
token = await self._build_special_token(
|
155
|
-
name=f"{self.current_node_name}:start",
|
156
|
-
params=None
|
147
|
+
name=f"{self.current_node_name}:start", params=None
|
157
148
|
)
|
158
149
|
self.queue.put_nowait(token)
|
159
150
|
# TODO JB: should we use two tokens here?
|
160
151
|
node_token = await self._build_special_token(
|
161
|
-
name=self.current_node_name,
|
162
|
-
params=None
|
152
|
+
name=self.current_node_name, params=None
|
163
153
|
)
|
164
154
|
self.queue.put_nowait(node_token)
|
165
|
-
|
155
|
+
|
166
156
|
async def end_node(self, node_name: str):
|
167
|
-
"""Emits the end token for the current node.
|
168
|
-
"""
|
157
|
+
"""Emits the end token for the current node."""
|
169
158
|
if self._done:
|
170
159
|
raise RuntimeError("Cannot end node on a closed stream")
|
171
|
-
#self.current_node_name = node_name
|
160
|
+
# self.current_node_name = node_name
|
172
161
|
if self.active:
|
173
162
|
node_token = await self._build_special_token(
|
174
|
-
name=f"{self.current_node_name}:end",
|
175
|
-
params=None
|
163
|
+
name=f"{self.current_node_name}:end", params=None
|
176
164
|
)
|
177
165
|
self.queue.put_nowait(node_token)
|
178
166
|
|
@@ -182,10 +170,7 @@ class Callback:
|
|
182
170
|
"""
|
183
171
|
if self._done:
|
184
172
|
return
|
185
|
-
end_token = await self._build_special_token(
|
186
|
-
name="END",
|
187
|
-
params=None
|
188
|
-
)
|
173
|
+
end_token = await self._build_special_token(name="END", params=None)
|
189
174
|
self._done = True # Set done before putting the end token
|
190
175
|
self.queue.put_nowait(end_token)
|
191
176
|
# Don't wait for queue.join() as it can cause deadlock
|
@@ -198,8 +183,10 @@ class Callback:
|
|
198
183
|
raise ValueError(
|
199
184
|
f"Node name mismatch: {self.current_node_name} != {node_name}"
|
200
185
|
)
|
201
|
-
|
202
|
-
async def _build_special_token(
|
186
|
+
|
187
|
+
async def _build_special_token(
|
188
|
+
self, name: str, params: dict[str, Any] | None = None
|
189
|
+
):
|
203
190
|
if params:
|
204
191
|
params_str = ",".join([f"{k}={v}" for k, v in params.items()])
|
205
192
|
else:
|
@@ -209,7 +196,5 @@ class Callback:
|
|
209
196
|
else:
|
210
197
|
identifier = ""
|
211
198
|
return self.special_token_format.format(
|
212
|
-
identifier=identifier,
|
213
|
-
token=name,
|
214
|
-
params=params_str
|
199
|
+
identifier=identifier, token=name, params=params_str
|
215
200
|
)
|
graphai/graph.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1
1
|
from typing import List, Dict, Any, Optional
|
2
2
|
from graphai.nodes.base import _Node
|
3
3
|
from graphai.callback import Callback
|
4
|
-
from
|
4
|
+
from graphai.utils import logger
|
5
5
|
|
6
6
|
|
7
7
|
class Graph:
|
8
|
-
def __init__(
|
8
|
+
def __init__(
|
9
|
+
self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
|
10
|
+
):
|
9
11
|
self.nodes: Dict[str, _Node] = {}
|
10
12
|
self.edges: List[Any] = []
|
11
13
|
self.start_node: Optional[_Node] = None
|
@@ -49,7 +51,7 @@ class Graph:
|
|
49
51
|
|
50
52
|
def add_edge(self, source: _Node | str, destination: _Node | str):
|
51
53
|
"""Adds an edge between two nodes that already exist in the graph.
|
52
|
-
|
54
|
+
|
53
55
|
Args:
|
54
56
|
source: The source node or its name.
|
55
57
|
destination: The destination node or its name.
|
@@ -60,7 +62,7 @@ class Graph:
|
|
60
62
|
source_node = self.nodes.get(source)
|
61
63
|
else:
|
62
64
|
# Check if it's a node-like object by looking for required attributes
|
63
|
-
if hasattr(source,
|
65
|
+
if hasattr(source, "name"):
|
64
66
|
source_node = self.nodes.get(source.name)
|
65
67
|
if source_node is None:
|
66
68
|
raise ValueError(
|
@@ -71,7 +73,7 @@ class Graph:
|
|
71
73
|
destination_node = self.nodes.get(destination)
|
72
74
|
else:
|
73
75
|
# Check if it's a node-like object by looking for required attributes
|
74
|
-
if hasattr(destination,
|
76
|
+
if hasattr(destination, "name"):
|
75
77
|
destination_node = self.nodes.get(destination.name)
|
76
78
|
if destination_node is None:
|
77
79
|
raise ValueError(
|
@@ -80,7 +82,9 @@ class Graph:
|
|
80
82
|
edge = Edge(source_node, destination_node)
|
81
83
|
self.edges.append(edge)
|
82
84
|
|
83
|
-
def add_router(
|
85
|
+
def add_router(
|
86
|
+
self, sources: list[_Node], router: _Node, destinations: List[_Node]
|
87
|
+
):
|
84
88
|
if not router.is_router:
|
85
89
|
raise TypeError("A router object must be passed to the router parameter.")
|
86
90
|
[self.add_edge(source, router) for source in sources]
|
@@ -126,7 +130,9 @@ class Graph:
|
|
126
130
|
# add callback tokens and param here if we are streaming
|
127
131
|
await self.callback.start_node(node_name=current_node.name)
|
128
132
|
# Include graph's internal state in the node execution context
|
129
|
-
output = await current_node.invoke(
|
133
|
+
output = await current_node.invoke(
|
134
|
+
input=state, callback=self.callback, state=self.state
|
135
|
+
)
|
130
136
|
self._validate_output(output=output, node_name=current_node.name)
|
131
137
|
await self.callback.end_node(node_name=current_node.name)
|
132
138
|
else:
|
@@ -164,13 +170,13 @@ class Graph:
|
|
164
170
|
|
165
171
|
def _get_node_by_name(self, node_name: str) -> _Node:
|
166
172
|
"""Get a node by its name.
|
167
|
-
|
173
|
+
|
168
174
|
Args:
|
169
175
|
node_name: The name of the node to find.
|
170
|
-
|
176
|
+
|
171
177
|
Returns:
|
172
178
|
The node with the given name.
|
173
|
-
|
179
|
+
|
174
180
|
Raises:
|
175
181
|
Exception: If no node with the given name is found.
|
176
182
|
"""
|
@@ -191,12 +197,16 @@ class Graph:
|
|
191
197
|
try:
|
192
198
|
import networkx as nx
|
193
199
|
except ImportError:
|
194
|
-
raise ImportError(
|
200
|
+
raise ImportError(
|
201
|
+
"NetworkX is required for visualization. Please install it with 'pip install networkx'."
|
202
|
+
)
|
195
203
|
|
196
204
|
try:
|
197
205
|
import matplotlib.pyplot as plt
|
198
206
|
except ImportError:
|
199
|
-
raise ImportError(
|
207
|
+
raise ImportError(
|
208
|
+
"Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
|
209
|
+
)
|
200
210
|
|
201
211
|
G = nx.DiGraph()
|
202
212
|
|
@@ -207,7 +217,9 @@ class Graph:
|
|
207
217
|
G.add_edge(edge.source.name, edge.destination.name)
|
208
218
|
|
209
219
|
if nx.is_directed_acyclic_graph(G):
|
210
|
-
logger.info(
|
220
|
+
logger.info(
|
221
|
+
"The graph is acyclic. Visualization will use a topological layout."
|
222
|
+
)
|
211
223
|
# Use topological layout if acyclic
|
212
224
|
# Compute the topological generations
|
213
225
|
generations = list(nx.topological_generations(G))
|
@@ -241,20 +253,30 @@ class Graph:
|
|
241
253
|
pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
|
242
254
|
|
243
255
|
else:
|
244
|
-
print(
|
256
|
+
print(
|
257
|
+
"Warning: The graph contains cycles. Visualization will use a spring layout."
|
258
|
+
)
|
245
259
|
pos = nx.spring_layout(G, k=1, iterations=50)
|
246
260
|
|
247
261
|
plt.figure(figsize=(8, 6))
|
248
|
-
nx.draw(
|
249
|
-
|
250
|
-
|
262
|
+
nx.draw(
|
263
|
+
G,
|
264
|
+
pos,
|
265
|
+
with_labels=True,
|
266
|
+
node_color="lightblue",
|
267
|
+
node_size=3000,
|
268
|
+
font_size=8,
|
269
|
+
font_weight="bold",
|
270
|
+
arrows=True,
|
271
|
+
edge_color="gray",
|
272
|
+
arrowsize=20,
|
273
|
+
)
|
251
274
|
|
252
|
-
plt.axis(
|
275
|
+
plt.axis("off")
|
253
276
|
plt.show()
|
254
277
|
|
255
278
|
|
256
|
-
|
257
279
|
class Edge:
|
258
280
|
def __init__(self, source, destination):
|
259
281
|
self.source = source
|
260
|
-
self.destination = destination
|
282
|
+
self.destination = destination
|
graphai/nodes/__init__.py
CHANGED
graphai/nodes/base.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import inspect
|
2
2
|
from typing import Any, Callable, Dict, Optional
|
3
|
+
from pydantic import Field
|
3
4
|
|
4
5
|
from graphai.callback import Callback
|
5
6
|
from graphai.utils import FunctionSchema
|
@@ -9,7 +10,11 @@ class NodeMeta(type):
|
|
9
10
|
@staticmethod
|
10
11
|
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
|
11
12
|
init_signature = inspect.signature(cls_type.__init__)
|
12
|
-
init_params = {
|
13
|
+
init_params = {
|
14
|
+
name: arg
|
15
|
+
for name, arg in init_signature.parameters.items()
|
16
|
+
if name != "self"
|
17
|
+
}
|
13
18
|
return init_params
|
14
19
|
|
15
20
|
def __call__(cls, *args, **kwargs):
|
@@ -33,18 +38,32 @@ class _Node:
|
|
33
38
|
stream: bool = False,
|
34
39
|
name: str | None = None,
|
35
40
|
) -> Callable:
|
36
|
-
"""Decorator validating node structure.
|
37
|
-
"""
|
41
|
+
"""Decorator validating node structure."""
|
38
42
|
if not callable(func):
|
39
43
|
raise ValueError("Node must be a callable function.")
|
40
|
-
|
44
|
+
|
41
45
|
func_signature = inspect.signature(func)
|
42
|
-
schema = FunctionSchema(func)
|
46
|
+
schema: FunctionSchema = FunctionSchema.from_callable(func)
|
43
47
|
|
44
48
|
class NodeClass:
|
45
49
|
_func_signature = func_signature
|
46
|
-
is_router =
|
47
|
-
|
50
|
+
is_router: bool = Field(
|
51
|
+
default=False, description="Whether the node is a router."
|
52
|
+
)
|
53
|
+
# following attributes will be overridden by the decorator
|
54
|
+
name: str | None = Field(default=None, description="The name of the node.")
|
55
|
+
is_start: bool = Field(
|
56
|
+
default=False, description="Whether the node is the start of the graph."
|
57
|
+
)
|
58
|
+
is_end: bool = Field(
|
59
|
+
default=False, description="Whether the node is the end of the graph."
|
60
|
+
)
|
61
|
+
schema: FunctionSchema | None = Field(
|
62
|
+
default=None, description="The schema of the node."
|
63
|
+
)
|
64
|
+
stream: bool = Field(
|
65
|
+
default=False, description="Whether the node includes streaming object."
|
66
|
+
)
|
48
67
|
|
49
68
|
def __init__(self):
|
50
69
|
self._expected_params = set(self._func_signature.parameters.keys())
|
@@ -56,9 +75,13 @@ class _Node:
|
|
56
75
|
|
57
76
|
async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
|
58
77
|
# filter out unexpected keyword args
|
59
|
-
expected_kwargs = {
|
78
|
+
expected_kwargs = {
|
79
|
+
k: v for k, v in kwargs.items() if k in self._expected_params
|
80
|
+
}
|
60
81
|
# Convert args to kwargs based on the function signature
|
61
|
-
args_names = list(self._func_signature.parameters.keys())[
|
82
|
+
args_names = list(self._func_signature.parameters.keys())[
|
83
|
+
1 : len(args) + 1
|
84
|
+
] # skip 'self'
|
62
85
|
expected_args_kwargs = dict(zip(args_names, args))
|
63
86
|
# Combine filtered args and kwargs
|
64
87
|
combined_params = {**expected_args_kwargs, **expected_kwargs}
|
@@ -87,7 +110,6 @@ class _Node:
|
|
87
110
|
)
|
88
111
|
return filtered_params
|
89
112
|
|
90
|
-
|
91
113
|
@classmethod
|
92
114
|
def get_signature(cls):
|
93
115
|
"""Returns the signature of the decorated function as LLM readable
|
@@ -97,15 +119,24 @@ class _Node:
|
|
97
119
|
if NodeClass._func_signature:
|
98
120
|
for param in NodeClass._func_signature.parameters.values():
|
99
121
|
if param.default is param.empty:
|
100
|
-
signature_components.append(
|
122
|
+
signature_components.append(
|
123
|
+
f"{param.name}: {param.annotation}"
|
124
|
+
)
|
101
125
|
else:
|
102
|
-
signature_components.append(
|
126
|
+
signature_components.append(
|
127
|
+
f"{param.name}: {param.annotation} = {param.default}"
|
128
|
+
)
|
103
129
|
else:
|
104
130
|
return "No signature"
|
105
131
|
return "\n".join(signature_components)
|
106
132
|
|
107
133
|
@classmethod
|
108
|
-
async def invoke(
|
134
|
+
async def invoke(
|
135
|
+
cls,
|
136
|
+
input: Dict[str, Any],
|
137
|
+
callback: Optional[Callback] = None,
|
138
|
+
state: Optional[Dict[str, Any]] = None,
|
139
|
+
):
|
109
140
|
if callback:
|
110
141
|
if stream:
|
111
142
|
input["callback"] = callback
|
@@ -116,13 +147,16 @@ class _Node:
|
|
116
147
|
# Add state to the input if present and the parameter exists in the function signature
|
117
148
|
if state is not None and "state" in cls._func_signature.parameters:
|
118
149
|
input["state"] = state
|
119
|
-
|
150
|
+
|
120
151
|
instance = cls()
|
121
152
|
out = await instance.execute(**input)
|
122
153
|
return out
|
123
154
|
|
124
155
|
NodeClass.__name__ = func.__name__
|
125
|
-
|
156
|
+
node_class_name = name or func.__name__
|
157
|
+
if node_class_name is None:
|
158
|
+
raise ValueError("Unexpected error: node name not set.")
|
159
|
+
NodeClass.name = node_class_name
|
126
160
|
NodeClass.__doc__ = func.__doc__
|
127
161
|
NodeClass.is_start = start
|
128
162
|
NodeClass.is_end = end
|
@@ -141,8 +175,11 @@ class _Node:
|
|
141
175
|
):
|
142
176
|
# We must wrap the call to the decorator in a function for it to work
|
143
177
|
# correctly with or without parenthesis
|
144
|
-
def wrap(
|
178
|
+
def wrap(
|
179
|
+
func: Callable, start=start, end=end, stream=stream, name=name
|
180
|
+
) -> Callable:
|
145
181
|
return self._node(func=func, start=start, end=end, stream=stream, name=name)
|
182
|
+
|
146
183
|
if func:
|
147
184
|
# Decorator is called without parenthesis
|
148
185
|
return wrap(func=func, start=start, end=end, stream=stream, name=name)
|
graphai/utils.py
CHANGED
@@ -1,11 +1,81 @@
|
|
1
1
|
import inspect
|
2
|
-
from typing import Any, Callable,
|
2
|
+
from typing import Any, Callable, List, Optional
|
3
3
|
from pydantic import BaseModel, Field
|
4
|
+
import logging
|
5
|
+
|
6
|
+
import colorlog
|
7
|
+
|
8
|
+
|
9
|
+
class CustomFormatter(colorlog.ColoredFormatter):
|
10
|
+
"""Custom formatter for the logger."""
|
11
|
+
|
12
|
+
def __init__(self):
|
13
|
+
super().__init__(
|
14
|
+
"%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
|
15
|
+
datefmt="%Y-%m-%d %H:%M:%S",
|
16
|
+
log_colors={
|
17
|
+
"DEBUG": "cyan",
|
18
|
+
"INFO": "green",
|
19
|
+
"WARNING": "yellow",
|
20
|
+
"ERROR": "red",
|
21
|
+
"CRITICAL": "bold_red",
|
22
|
+
},
|
23
|
+
reset=True,
|
24
|
+
style="%",
|
25
|
+
)
|
26
|
+
|
27
|
+
|
28
|
+
def add_coloured_handler(logger):
|
29
|
+
"""Add a coloured handler to the logger."""
|
30
|
+
formatter = CustomFormatter()
|
31
|
+
console_handler = logging.StreamHandler()
|
32
|
+
console_handler.setFormatter(formatter)
|
33
|
+
logger.addHandler(console_handler)
|
34
|
+
return logger
|
35
|
+
|
36
|
+
|
37
|
+
def setup_custom_logger(name):
|
38
|
+
"""Setup a custom logger."""
|
39
|
+
logger = logging.getLogger(name)
|
40
|
+
|
41
|
+
if not logger.hasHandlers():
|
42
|
+
add_coloured_handler(logger)
|
43
|
+
logger.setLevel(logging.INFO)
|
44
|
+
logger.propagate = False
|
45
|
+
|
46
|
+
return logger
|
47
|
+
|
48
|
+
|
49
|
+
logger: logging.Logger = setup_custom_logger(__name__)
|
50
|
+
|
51
|
+
|
52
|
+
def openai_type_mapping(param_type: str) -> str:
|
53
|
+
if param_type == "int":
|
54
|
+
return "number"
|
55
|
+
elif param_type == "float":
|
56
|
+
return "number"
|
57
|
+
elif param_type == "str":
|
58
|
+
return "string"
|
59
|
+
elif param_type == "bool":
|
60
|
+
return "boolean"
|
61
|
+
else:
|
62
|
+
return "object"
|
4
63
|
|
5
64
|
|
6
65
|
class Parameter(BaseModel):
|
7
|
-
|
8
|
-
|
66
|
+
"""Parameter for a function.
|
67
|
+
|
68
|
+
:param name: The name of the parameter.
|
69
|
+
:type name: str
|
70
|
+
:param description: The description of the parameter.
|
71
|
+
:type description: Optional[str]
|
72
|
+
:param type: The type of the parameter.
|
73
|
+
:type type: str
|
74
|
+
:param default: The default value of the parameter.
|
75
|
+
:type default: Any
|
76
|
+
:param required: Whether the parameter is required.
|
77
|
+
:type required: bool
|
78
|
+
"""
|
9
79
|
|
10
80
|
name: str = Field(description="The name of the parameter")
|
11
81
|
description: Optional[str] = Field(
|
@@ -15,15 +85,22 @@ class Parameter(BaseModel):
|
|
15
85
|
default: Any = Field(description="The default value of the parameter")
|
16
86
|
required: bool = Field(description="Whether the parameter is required")
|
17
87
|
|
18
|
-
def
|
88
|
+
def to_dict(self) -> dict[str, Any]:
|
89
|
+
"""Convert the parameter to a dictionary for an standard dictionary-based function schema.
|
90
|
+
This is the most common format used by LLM providers, including OpenAI, Ollama, and others.
|
91
|
+
|
92
|
+
:return: The parameter in dictionary format.
|
93
|
+
:rtype: dict[str, Any]
|
94
|
+
"""
|
19
95
|
return {
|
20
96
|
self.name: {
|
21
97
|
"description": self.description,
|
22
|
-
"type": self.type,
|
98
|
+
"type": openai_type_mapping(self.type),
|
23
99
|
}
|
24
100
|
}
|
25
101
|
|
26
|
-
|
102
|
+
|
103
|
+
class FunctionSchema(BaseModel):
|
27
104
|
"""Class that consumes a function and can return a schema required by
|
28
105
|
different LLMs for function calling.
|
29
106
|
"""
|
@@ -32,35 +109,68 @@ class FunctionSchema:
|
|
32
109
|
description: str = Field(description="The description of the function")
|
33
110
|
signature: str = Field(description="The signature of the function")
|
34
111
|
output: str = Field(description="The output of the function")
|
35
|
-
parameters:
|
112
|
+
parameters: list[Parameter] = Field(description="The parameters of the function")
|
113
|
+
|
114
|
+
@classmethod
|
115
|
+
def from_callable(cls, function: Callable) -> "FunctionSchema":
|
116
|
+
"""Initialize the FunctionSchema.
|
36
117
|
|
37
|
-
|
38
|
-
|
118
|
+
:param function: The function to consume.
|
119
|
+
:type function: Callable
|
120
|
+
"""
|
39
121
|
if callable(function):
|
40
|
-
|
122
|
+
name = function.__name__
|
123
|
+
description = str(inspect.getdoc(function))
|
124
|
+
if description is None or description == "":
|
125
|
+
logger.warning(f"Function {name} has no docstring")
|
126
|
+
signature = str(inspect.signature(function))
|
127
|
+
output = str(inspect.signature(function).return_annotation)
|
128
|
+
parameters = []
|
129
|
+
for param in inspect.signature(function).parameters.values():
|
130
|
+
parameters.append(
|
131
|
+
Parameter(
|
132
|
+
name=param.name,
|
133
|
+
type=param.annotation.__name__,
|
134
|
+
default=param.default,
|
135
|
+
required=param.default is inspect.Parameter.empty,
|
136
|
+
)
|
137
|
+
)
|
138
|
+
return cls.model_construct(
|
139
|
+
name=name,
|
140
|
+
description=description,
|
141
|
+
signature=signature,
|
142
|
+
output=output,
|
143
|
+
parameters=parameters,
|
144
|
+
)
|
41
145
|
elif isinstance(function, BaseModel):
|
42
146
|
raise NotImplementedError("Pydantic BaseModel not implemented yet.")
|
43
147
|
else:
|
44
148
|
raise TypeError("Function must be a Callable or BaseModel")
|
45
149
|
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
type=param.annotation.__name__,
|
57
|
-
default=param.default,
|
58
|
-
required=param.default is inspect.Parameter.empty,
|
150
|
+
@classmethod
|
151
|
+
def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
|
152
|
+
signature_parts = []
|
153
|
+
for field_name, field_model in model.__annotations__.items():
|
154
|
+
field_info = model.model_fields[field_name]
|
155
|
+
default_value = field_info.default
|
156
|
+
if default_value:
|
157
|
+
default_repr = repr(default_value)
|
158
|
+
signature_part = (
|
159
|
+
f"{field_name}: {field_model.__name__} = {default_repr}"
|
59
160
|
)
|
60
|
-
|
61
|
-
|
161
|
+
else:
|
162
|
+
signature_part = f"{field_name}: {field_model.__name__}"
|
163
|
+
signature_parts.append(signature_part)
|
164
|
+
signature = f"({', '.join(signature_parts)}) -> str"
|
165
|
+
return cls.model_construct(
|
166
|
+
name=model.__class__.__name__,
|
167
|
+
description=model.__doc__ or "",
|
168
|
+
signature=signature,
|
169
|
+
output="", # TODO: Implement output
|
170
|
+
parameters=[],
|
171
|
+
)
|
62
172
|
|
63
|
-
def
|
173
|
+
def to_dict(self) -> dict:
|
64
174
|
schema_dict = {
|
65
175
|
"type": "function",
|
66
176
|
"function": {
|
@@ -69,15 +179,7 @@ class FunctionSchema:
|
|
69
179
|
"parameters": {
|
70
180
|
"type": "object",
|
71
181
|
"properties": {
|
72
|
-
param.
|
73
|
-
"description": (
|
74
|
-
param.description
|
75
|
-
if isinstance(param.description, str)
|
76
|
-
else "None provided"
|
77
|
-
),
|
78
|
-
"type": self._openai_type_mapping(param.type),
|
79
|
-
}
|
80
|
-
for param in self.parameters
|
182
|
+
k: v for param in self.parameters for k, v in param.to_dict().items()
|
81
183
|
},
|
82
184
|
"required": [
|
83
185
|
param.name for param in self.parameters if param.required
|
@@ -87,39 +189,17 @@ class FunctionSchema:
|
|
87
189
|
}
|
88
190
|
return schema_dict
|
89
191
|
|
90
|
-
def
|
91
|
-
|
92
|
-
return "number"
|
93
|
-
elif param_type == "float":
|
94
|
-
return "number"
|
95
|
-
elif param_type == "str":
|
96
|
-
return "string"
|
97
|
-
elif param_type == "bool":
|
98
|
-
return "boolean"
|
99
|
-
else:
|
100
|
-
return "object"
|
192
|
+
def to_openai(self) -> dict:
|
193
|
+
return self.to_dict()
|
101
194
|
|
102
195
|
|
103
|
-
|
104
|
-
signature_parts = []
|
105
|
-
for field_name, field_model in model.__annotations__.items():
|
106
|
-
field_info = model.__fields__[field_name]
|
107
|
-
default_value = field_info.default
|
196
|
+
DEFAULT = set(["default", "openai", "ollama", "litellm"])
|
108
197
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
)
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
signature_parts.append(signature_part)
|
118
|
-
signature = f"({', '.join(signature_parts)}) -> str"
|
119
|
-
schema = FunctionSchema(
|
120
|
-
name=model.__class__.__name__,
|
121
|
-
description=model.__doc__,
|
122
|
-
signature=signature,
|
123
|
-
output="", # TODO: Implement output
|
124
|
-
)
|
125
|
-
return schema
|
198
|
+
|
199
|
+
def get_schemas(callables: List[Callable], format: str = "default") -> list[dict]:
|
200
|
+
if format in DEFAULT:
|
201
|
+
return [
|
202
|
+
FunctionSchema.from_callable(callable).to_dict() for callable in callables
|
203
|
+
]
|
204
|
+
else:
|
205
|
+
raise ValueError(f"Format {format} not supported")
|
@@ -1,12 +1,13 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: graphai-lib
|
3
|
-
Version: 0.0.
|
3
|
+
Version: 0.0.5
|
4
4
|
Summary: Not an AI framework
|
5
5
|
Requires-Python: <3.14,>=3.10
|
6
6
|
Description-Content-Type: text/markdown
|
7
|
-
Requires-Dist: semantic-router>=0.1.5
|
8
7
|
Requires-Dist: networkx>=3.4.2
|
9
8
|
Requires-Dist: matplotlib>=3.10.0
|
9
|
+
Requires-Dist: pydantic>=2.11.1
|
10
|
+
Requires-Dist: colorlog>=6.9.0
|
10
11
|
Provides-Extra: dev
|
11
12
|
Requires-Dist: ipykernel>=6.25.0; extra == "dev"
|
12
13
|
Requires-Dist: ruff>=0.1.5; extra == "dev"
|
@@ -16,7 +17,7 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
16
17
|
Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
|
17
18
|
Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
|
18
19
|
Requires-Dist: mypy>=1.7.1; extra == "dev"
|
19
|
-
Requires-Dist:
|
20
|
+
Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
|
20
21
|
Provides-Extra: docs
|
21
22
|
Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
|
22
23
|
|
@@ -0,0 +1,10 @@
|
|
1
|
+
graphai/__init__.py,sha256=kZJ21W6gwN-eRzvrWQf8xDTPCIWTIuyWq1IGDS9tn7Y,110
|
2
|
+
graphai/callback.py,sha256=M2gEpj7uVvANg2dVgxKFMUSPIM362YSilShMIFfrr8s,7351
|
3
|
+
graphai/graph.py,sha256=5RgG5mYE8xyRCa68w84MXKwCBOnbdXMmgg1Zx5kZN9k,10563
|
4
|
+
graphai/utils.py,sha256=LlL-Wx643nIeRFAl2xcv0crNQcA_0563epRo8ZsyL40,6898
|
5
|
+
graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
|
6
|
+
graphai/nodes/base.py,sha256=-ZOfJhxews5CGutYB5lfoIVvJ6dqdYJWeeKzDMz9odg,7624
|
7
|
+
graphai_lib-0.0.5.dist-info/METADATA,sha256=SO5iiJuOrY5uDCPa44IiRoPTG4kbms8Qtb29aCtHqmw,1006
|
8
|
+
graphai_lib-0.0.5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
9
|
+
graphai_lib-0.0.5.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
|
10
|
+
graphai_lib-0.0.5.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
graphai/__init__.py,sha256=EHigFOWewDXLZXbdfjZH9kdLPhw6NT0ChS77lNAVAA8,109
|
2
|
-
graphai/callback.py,sha256=K-h44pyL2VLXwJzIB_bcVYp5R6xv8zNca5FmN6994Uk,7598
|
3
|
-
graphai/graph.py,sha256=EALHEhbXAaJmTvm7cL3Tdh0moRIw7lIyJDCNnCts2QA,10335
|
4
|
-
graphai/utils.py,sha256=zrgpk82rIn7lwh631KhN-OgMAJMdbm0k5GPL1eMf2sQ,4522
|
5
|
-
graphai/nodes/__init__.py,sha256=4826Ubk5yUfbVH7F8DmoTKQyax624Q2QJHsGxqgQ_ng,73
|
6
|
-
graphai/nodes/base.py,sha256=SZdYhFfXdtFmabFbMRcEGd8_h8w-g6s4I7hMEo6JCk8,6331
|
7
|
-
graphai_lib-0.0.4.dist-info/METADATA,sha256=wfxK82ZO-slBl-xVNmdbYGVvPo7YKm8-SLfdrkNWtKs,982
|
8
|
-
graphai_lib-0.0.4.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
|
9
|
-
graphai_lib-0.0.4.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
|
10
|
-
graphai_lib-0.0.4.dist-info/RECORD,,
|
File without changes
|