graphai-lib 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphai_lib-0.0.1/PKG-INFO +25 -0
- graphai_lib-0.0.1/README.md +6 -0
- graphai_lib-0.0.1/graphai/__init__.py +4 -0
- graphai_lib-0.0.1/graphai/callback.py +63 -0
- graphai_lib-0.0.1/graphai/graph.py +198 -0
- graphai_lib-0.0.1/graphai/nodes/__init__.py +3 -0
- graphai_lib-0.0.1/graphai/nodes/base.py +148 -0
- graphai_lib-0.0.1/graphai/utils.py +125 -0
- graphai_lib-0.0.1/pyproject.toml +29 -0
@@ -0,0 +1,25 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: graphai-lib
|
3
|
+
Version: 0.0.1
|
4
|
+
Summary:
|
5
|
+
License: MIT
|
6
|
+
Author: Aurelio AI
|
7
|
+
Author-email: hello@aurelio.ai
|
8
|
+
Requires-Python: >=3.10,<3.14
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
11
|
+
Classifier: Programming Language :: Python :: 3.10
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
15
|
+
Requires-Dist: matplotlib (>=3.10.0,<4.0.0)
|
16
|
+
Requires-Dist: networkx (>=3.4.2,<4.0.0)
|
17
|
+
Requires-Dist: semantic-router (>=0.1.0.dev4)
|
18
|
+
Description-Content-Type: text/markdown
|
19
|
+
|
20
|
+
# Philosophy
|
21
|
+
|
22
|
+
1. Async-first
|
23
|
+
2. Minimize abstractions
|
24
|
+
3. One way to do one thing
|
25
|
+
4. Graph-based AI
|
@@ -0,0 +1,63 @@
|
|
1
|
+
import asyncio
|
2
|
+
from typing import Optional
|
3
|
+
from collections.abc import AsyncIterator
|
4
|
+
from semantic_router.utils.logger import logger
|
5
|
+
|
6
|
+
|
7
|
+
log_stream = True


class Callback:
    """Async token stream shared between graph nodes and a consumer.

    Producers push tokens with ``__call__`` / ``acall`` and node boundary
    markers with ``start_node`` / ``end_node``. The consumer drains everything
    through ``aiter`` until the ``"<graphai:END>"`` sentinel enqueued by
    ``close`` has been yielded.
    """

    first_token = True
    current_node_name: Optional[str] = None
    active: bool = True
    queue: asyncio.Queue

    def __init__(self):
        self.queue = asyncio.Queue()

    def __call__(self, token: str, node_name: Optional[str] = None):
        """Enqueue ``token`` synchronously.

        When ``node_name`` is given it must match the node currently running,
        otherwise ``ValueError`` is raised and nothing is enqueued.
        """
        self._check_node_name(node_name=node_name)
        self.queue.put_nowait(token)

    async def acall(self, token: str, node_name: Optional[str] = None):
        """Coroutine flavour of ``__call__`` with identical checks."""
        self._check_node_name(node_name=node_name)
        self.queue.put_nowait(token)

    async def aiter(self) -> AsyncIterator[str]:
        """Yield queued tokens in order, stopping after the END sentinel.

        The sentinel itself is yielded before iteration finishes.
        """
        finished = False
        while not finished:
            item = await self.queue.get()
            yield item
            self.queue.task_done()
            finished = item == "<graphai:END>"

    async def start_node(self, node_name: str, active: bool = True):
        """Mark ``node_name`` as running.

        A start marker is enqueued only when ``active`` is true.
        """
        self.current_node_name = node_name
        if self.first_token:
            # TODO JB: not sure if we need self.first_token
            self.first_token = False
        self.active = active
        if self.active:
            self.queue.put_nowait(f"<graphai:start:{node_name}>")

    async def end_node(self, node_name: str):
        """Clear the running node; emit an end marker if the node was active."""
        self.current_node_name = None
        if not self.active:
            return
        self.queue.put_nowait(f"<graphai:end:{node_name}>")

    async def close(self):
        """Enqueue the END sentinel that terminates ``aiter``."""
        self.queue.put_nowait("<graphai:END>")

    def _check_node_name(self, node_name: Optional[str] = None):
        """Raise ``ValueError`` if a non-empty ``node_name`` is not the current node."""
        if not node_name:
            return
        if self.current_node_name != node_name:
            raise ValueError(
                f"Node name mismatch: {self.current_node_name} != {node_name}"
            )
|
@@ -0,0 +1,198 @@
|
|
1
|
+
from typing import List, Dict, Any
|
2
|
+
from graphai.nodes.base import _Node
|
3
|
+
from graphai.callback import Callback
|
4
|
+
from semantic_router.utils.logger import logger
|
5
|
+
|
6
|
+
|
7
|
+
class Graph:
    """A directed graph of node classes executed from a start node to an end node.

    Nodes are the classes produced by the ``@node`` / ``@router`` decorators.
    Edges are :class:`Edge` objects. ``execute`` walks the graph, merging each
    node's dict output into a shared state dict.
    """

    def __init__(self, max_steps: int = 10):
        self.nodes = []
        self.edges = []
        self.start_node = None
        self.end_nodes = []
        self.Callback = Callback
        self.callback = None
        self.max_steps = max_steps

    def add_node(self, node):
        """Register ``node``; record it as start and/or end node per its flags.

        Raises:
            Exception: if ``node`` is a start node and a start node already
                exists. The graph is left unchanged in that case.
        """
        # Validate before mutating so a rejected node is not left in self.nodes.
        if node.is_start and self.start_node is not None:
            raise Exception(
                "Multiple start nodes are not allowed. Start node "
                f"'{self.start_node.name}' already exists, so new start "
                f"node '{node.name}' can not be added to the graph."
            )
        self.nodes.append(node)
        if node.is_start:
            self.start_node = node
        if node.is_end:
            self.end_nodes.append(node)

    def add_edge(self, source: _Node, destination: _Node):
        """Add a directed edge from ``source`` to ``destination``."""
        # TODO add logic to check that source and destination are nodes
        # and they exist in the graph object already
        edge = Edge(source, destination)
        self.edges.append(edge)

    def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
        """Connect every source to ``router`` and ``router`` to every destination.

        Raises:
            TypeError: if ``router`` is not a router node.
        """
        if not router.is_router:
            raise TypeError("A router object must be passed to the router parameter.")
        for source in sources:
            self.add_edge(source, router)
        for destination in destinations:
            self.add_edge(router, destination)

    def set_start_node(self, node: _Node):
        """Make ``node`` the node execution begins from."""
        self.start_node = node

    def set_end_node(self, node: _Node):
        """Register ``node`` as an end node.

        The node is added to ``end_nodes`` (used by ``compile`` and
        ``execute``). ``self.end_node`` is still set for backward compatibility.
        """
        self.end_node = node
        if node not in self.end_nodes:
            self.end_nodes.append(node)

    def compile(self):
        """Check that the graph has a start node, end nodes and passes validation.

        Raises:
            Exception: describing the first failed check.
        """
        if not self.start_node:
            raise Exception("Start node not defined.")
        if not self.end_nodes:
            raise Exception("No end nodes defined.")
        if not self._is_valid():
            raise Exception("Graph is not valid.")

    def _is_valid(self):
        # Implement validation logic, e.g., checking for cycles, disconnected components, etc.
        return True

    def _validate_output(self, output: Dict[str, Any], node_name: str):
        """Raise ``ValueError`` unless ``output`` is a dict."""
        if not isinstance(output, dict):
            raise ValueError(
                f"Expected dictionary output from node {node_name}. "
                f"Instead, got {type(output)} from '{output}'."
            )

    async def execute(self, input):
        """Run the graph from the start node until an end node is reached.

        Args:
            input: Initial state dict. It is copied, so the caller's dict is
                never modified.

        Returns:
            The final state: ``input`` with every executed node's output merged
            on top. The internal ``"callback"`` entry is removed.

        Raises:
            Exception: if no start node is set, ``max_steps`` is exceeded, a
                routed-to node does not exist, or a non-end node has no
                outgoing edge.
            ValueError: if a node returns something other than a dict.
            KeyError: if a router node's output has no ``"choice"`` key.
        """
        # TODO JB: may need to add init callback here to init the queue on every new execution
        if self.callback is None:
            self.callback = self.get_callback()
        current_node = self.start_node
        if current_node is None:
            raise Exception("Start node not defined.")
        # Copy: a streaming node's ``invoke`` injects "callback" into the dict
        # it receives, which would otherwise be the caller's own dict.
        state = dict(input)
        steps = 0
        while True:
            # we invoke the node here
            if current_node.stream:
                # add callback tokens and param here if we are streaming
                await self.callback.start_node(node_name=current_node.name)
                output = await current_node.invoke(input=state, callback=self.callback)
                self._validate_output(output=output, node_name=current_node.name)
                await self.callback.end_node(node_name=current_node.name)
            else:
                output = await current_node.invoke(input=state)
                self._validate_output(output=output, node_name=current_node.name)
            # add output to state
            state = {**state, **output}
            # Stop at nodes flagged is_end and at nodes registered via set_end_node.
            if current_node.is_end or current_node in self.end_nodes:
                break
            if current_node.is_router:
                # if we have a router node we let the router decide the next node
                next_node_name = str(output["choice"])
                # NOTE: "choice" was already merged into ``state`` above, so it
                # stays visible to later nodes; this only removes it from the
                # router's own output dict.
                del output["choice"]
                current_node = self._get_node_by_name(node_name=next_node_name)
            else:
                # otherwise, we have linear path
                current_node = self._get_next_node(current_node=current_node)
            steps += 1
            if steps >= self.max_steps:
                raise Exception(
                    f"Max steps reached: {self.max_steps}. You can modify this "
                    "by setting `max_steps` when initializing the Graph object."
                )
        # TODO JB: may need to add end callback here to close the queue for every execution
        # "callback" is only in state if a streaming node ran (its invoke added it).
        if self.callback and "callback" in state:
            await self.callback.close()
            del state["callback"]
        return state

    def get_callback(self):
        """Create a fresh callback, store it on the graph and return it."""
        self.callback = self.Callback()
        return self.callback

    def _get_node_by_name(self, node_name: str) -> _Node:
        """Return the first registered node named ``node_name``.

        Raises:
            Exception: if no such node exists.
        """
        for node in self.nodes:
            if node.name == node_name:
                return node
        raise Exception(f"Node with name {node_name} not found.")

    def _get_next_node(self, current_node):
        """Return the destination of the first edge leaving ``current_node``.

        Raises:
            Exception: if ``current_node`` has no outgoing edge.
        """
        for edge in self.edges:
            if edge.source == current_node:
                return edge.destination
        raise Exception(
            f"No outgoing edge found for current node '{current_node.name}'."
        )

    def visualize(self):
        """Draw the graph with matplotlib.

        Acyclic graphs use a layered (topological) layout. Graphs with cycles
        fall back to a spring layout.

        Raises:
            ImportError: if networkx or matplotlib is not installed.
        """
        try:
            import networkx as nx
        except ImportError as err:
            raise ImportError(
                "NetworkX is required for visualization. Please install it with 'pip install networkx'."
            ) from err

        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError(
                "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
            ) from err

        G = nx.DiGraph()

        for node in self.nodes:
            G.add_node(node.name)

        for edge in self.edges:
            G.add_edge(edge.source.name, edge.destination.name)

        if nx.is_directed_acyclic_graph(G):
            logger.info("The graph is acyclic. Visualization will use a topological layout.")
            generations = list(nx.topological_generations(G))
            y_max = len(generations)

            pos = {}
            for i, generation in enumerate(generations):
                # Earlier generations sit higher; the last one is at y = 0.
                y = y_max - i - 1
                # Place nodes at x = 0..k-1, then shift so the generation is
                # centred horizontally on x = 0.
                x_center = (len(generation) - 1) / 2
                for x, node_name in enumerate(generation):
                    pos[node_name] = (x - x_center, y)

            # Scale so the largest coordinate magnitude is 0.8. A zero extent
            # occurs for a single column of nodes or a single generation, so
            # guard against dividing by it. ``default`` handles an empty graph.
            max_x = max((abs(x) for x, _ in pos.values()), default=0)
            max_y = max((abs(y) for _, y in pos.values()), default=0)
            extent = max(max_x, max_y)
            scale = 0.8 / extent if extent else 1.0
            pos = {name: (x * scale, y * scale) for name, (x, y) in pos.items()}
        else:
            logger.warning("The graph contains cycles. Visualization will use a spring layout.")
            pos = nx.spring_layout(G, k=1, iterations=50)

        plt.figure(figsize=(8, 6))
        nx.draw(G, pos, with_labels=True, node_color='lightblue',
                node_size=3000, font_size=8, font_weight='bold',
                arrows=True, edge_color='gray', arrowsize=20)

        plt.axis('off')
        plt.show()
|
192
|
+
|
193
|
+
|
194
|
+
|
195
|
+
class Edge:
    """A directed connection between two graph nodes.

    Graph follows ``source`` -> ``destination`` when picking the next node
    and when drawing the graph.
    """

    def __init__(self, source, destination):
        self.source, self.destination = source, destination
|
@@ -0,0 +1,148 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import Any, Callable, Dict, Optional
|
3
|
+
|
4
|
+
from graphai.callback import Callback
|
5
|
+
from graphai.utils import FunctionSchema
|
6
|
+
|
7
|
+
|
8
|
+
class NodeMeta(type):
    """Metaclass that converts positional constructor arguments into keyword
    arguments (named after ``__init__``'s parameters) before instantiation.
    """

    @staticmethod
    def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """Map positional ``args`` onto the parameter names of ``cls_type.__init__``.

        Only positional-capable parameters (excluding ``self``) are used, in
        declaration order.

        Raises:
            TypeError: if more positional arguments are given than there are
                positional parameters to receive them.
        """
        init_signature = inspect.signature(cls_type.__init__)
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        param_names = [
            name
            for name, param in init_signature.parameters.items()
            if name != "self" and param.kind in positional_kinds
        ]
        if len(args) > len(param_names):
            raise TypeError(
                f"{cls_type.__name__}() takes {len(param_names)} positional "
                f"argument(s) but {len(args)} were given"
            )
        return dict(zip(param_names, args))

    def __call__(cls, *args, **kwargs):
        """Instantiate ``cls`` passing every argument by keyword.

        Raises:
            TypeError: if a value is supplied both positionally and by keyword.
        """
        named_positional_args = NodeMeta.positional_to_kwargs(cls, args)
        duplicated = set(named_positional_args) & set(kwargs)
        if duplicated:
            raise TypeError(
                f"{cls.__name__}() got multiple values for argument(s): "
                f"{', '.join(sorted(duplicated))}"
            )
        kwargs.update(named_positional_args)
        return super().__call__(**kwargs)
|
19
|
+
|
20
|
+
|
21
|
+
class _Node:
    """Decorator factory that turns functions into graph node classes.

    Instances are used as ``@node`` / ``@node(...)`` (regular nodes) or, with
    ``is_router=True``, as router-node decorators.
    """

    def __init__(
        self,
        is_router: bool = False,
    ):
        self.is_router = is_router

    def _node(
        self,
        func: Callable,
        start: bool = False,
        end: bool = False,
        stream: bool = False,
    ) -> Callable:
        """Decorator validating node structure.

        Wraps ``func`` in a freshly created ``NodeClass``. The class carries
        the metadata read by ``Graph`` (``name``, ``is_start``, ``is_end``,
        ``is_router``, ``stream``, ``schema``) and is returned.

        Args:
            func: The node body. Normally an ``async def`` function; a plain
                function is also accepted and its return value used as-is.
            start: Whether this node is the graph's start node.
            end: Whether this node is an end node.
            stream: Whether the node receives the streaming ``callback``.

        Raises:
            ValueError: if ``func`` is not callable.
        """
        if not callable(func):
            raise ValueError("Node must be a callable function.")

        func_signature = inspect.signature(func)
        schema = FunctionSchema(func)

        class NodeClass:
            _func_signature = func_signature
            is_router = None
            _stream = stream

            def __init__(self):
                # Names of every parameter ``func`` declares.
                self._expected_params = set(self._func_signature.parameters.keys())

            async def execute(self, *args, **kwargs):
                """Call ``func`` with only the arguments it declares."""
                # Prepare arguments, including callback if stream is True
                params_dict = await self._parse_params(*args, **kwargs)
                result = func(**params_dict)  # Pass only the necessary arguments
                # ``func`` is usually a coroutine function. Awaiting only when
                # needed also lets plain functions serve as node bodies.
                if inspect.isawaitable(result):
                    result = await result
                return result

            async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
                """Select, default-fill and validate the arguments for ``func``.

                Raises:
                    ValueError: if ``func`` takes ``callback`` but the node is
                        not streaming, or if a required parameter is missing.
                """
                # filter out unexpected keyword args
                expected_kwargs = {
                    k: v for k, v in kwargs.items() if k in self._expected_params
                }
                # Map positional args onto ``func``'s parameters in order.
                # ``func`` is a plain function, not a bound method, so there is
                # no ``self`` parameter to skip.
                args_names = list(self._func_signature.parameters.keys())[: len(args)]
                expected_args_kwargs = dict(zip(args_names, args))
                # Combine filtered args and kwargs (keywords win on clashes)
                combined_params = {**expected_args_kwargs, **expected_kwargs}

                if "callback" in self._expected_params and not stream:
                    raise ValueError(
                        f"Node {func.__name__}: requires stream=True when callback is defined."
                    )
                bound_params = self._func_signature.bind_partial(**combined_params)
                # get the default parameters (if any)
                bound_params.apply_defaults()
                params_dict = bound_params.arguments.copy()
                # Filter arguments to match the next node's parameters
                filtered_params = {
                    k: v for k, v in params_dict.items() if k in self._expected_params
                }
                # Confirm all required parameters are present. Iterate the
                # signature (not the set) so the error lists them in a stable order.
                missing_params = [
                    p for p in self._func_signature.parameters if p not in filtered_params
                ]
                if missing_params:
                    raise ValueError(
                        f"Missing required parameters for the {func.__name__} node: {', '.join(missing_params)}"
                    )
                return filtered_params

            @classmethod
            def get_signature(cls):
                """Returns the signature of the decorated function as LLM readable
                string.
                """
                signature_components = []
                if NodeClass._func_signature:
                    for param in NodeClass._func_signature.parameters.values():
                        if param.default is param.empty:
                            signature_components.append(f"{param.name}: {param.annotation}")
                        else:
                            signature_components.append(f"{param.name}: {param.annotation} = {param.default}")
                else:
                    return "No signature"
                return "\n".join(signature_components)

            @classmethod
            async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None):
                """Instantiate the node and run it on the ``input`` state dict.

                Raises:
                    ValueError: if ``callback`` is given but the node was not
                        declared with ``stream=True``.
                """
                if callback:
                    if stream:
                        # NOTE: mutates ``input`` in place. Graph.execute later
                        # checks for this "callback" key to close the stream.
                        input["callback"] = callback
                    else:
                        raise ValueError(
                            f"Error in node {func.__name__}. When callback provided, stream must be True."
                        )
                instance = cls()
                out = await instance.execute(**input)
                return out

        NodeClass.__name__ = func.__name__
        NodeClass.name = func.__name__
        NodeClass.__doc__ = func.__doc__
        NodeClass.is_start = start
        NodeClass.is_end = end
        NodeClass.is_router = self.is_router
        NodeClass.stream = stream
        NodeClass.schema = schema
        return NodeClass

    def __call__(
        self,
        func: Optional[Callable] = None,
        start: bool = False,
        end: bool = False,
        stream: bool = False,
    ):
        """Use as ``@decorator`` or ``@decorator(start=..., end=..., stream=...)``."""
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parenthesis
        def wrap(func: Callable, start=start, end=end, stream=stream) -> Callable:
            return self._node(func=func, start=start, end=end, stream=stream)

        if func is not None:
            # Decorator is called without parenthesis
            return wrap(func=func, start=start, end=end, stream=stream)
        # Decorator is called with parenthesis
        return wrap
|
145
|
+
|
146
|
+
|
147
|
+
# Module-level decorators: ``node`` wraps regular graph nodes, while
# ``router`` wraps router nodes (their output's "choice" key names the next node).
node, router = _Node(), _Node(is_router=True)
|
@@ -0,0 +1,125 @@
|
|
1
|
+
import inspect
|
2
|
+
from typing import Any, Callable, Dict, List, Union, Optional
|
3
|
+
from pydantic import BaseModel, Field
|
4
|
+
|
5
|
+
|
6
|
+
class Parameter(BaseModel):
    """One parameter of a function, as described in an LLM tool schema."""

    class Config:
        # permit non-pydantic types on this model
        arbitrary_types_allowed = True

    name: str = Field(description="The name of the parameter")
    description: Optional[str] = Field(
        default=None, description="The description of the parameter"
    )
    type: str = Field(description="The type of the parameter")
    default: Any = Field(description="The default value of the parameter")
    required: bool = Field(description="Whether the parameter is required")

    def to_openai(self):
        """Return ``{name: {"description": ..., "type": ...}}`` for this parameter."""
        spec = dict(description=self.description, type=self.type)
        return {self.name: spec}
|
25
|
+
|
26
|
+
class FunctionSchema:
    """Class that consumes a function and can return a schema required by
    different LLMs for function calling.
    """

    name: str = Field(description="The name of the function")
    description: str = Field(description="The description of the function")
    signature: str = Field(description="The signature of the function")
    output: str = Field(description="The output of the function")
    parameters: List[Parameter] = Field(description="The parameters of the function")

    # Python type name -> OpenAI JSON-schema type; anything else is "object".
    _OPENAI_TYPE_MAP = {
        "int": "number",
        "float": "number",
        "str": "string",
        "bool": "boolean",
    }

    def __init__(self, function: Union[Callable, BaseModel]):
        """Build the schema from ``function``.

        Raises:
            NotImplementedError: for (non-callable) pydantic model instances.
            TypeError: for anything else that is not callable.
        """
        self.function = function
        if callable(function):
            self._process_function(function)
        elif isinstance(function, BaseModel):
            raise NotImplementedError("Pydantic BaseModel not implemented yet.")
        else:
            raise TypeError("Function must be a Callable or BaseModel")

    @staticmethod
    def _annotation_name(annotation: Any) -> str:
        """Return a readable name for a type annotation.

        Classes use ``__name__``. Annotations without one (e.g. ``int | None``,
        or strings produced by ``from __future__ import annotations``) fall
        back to ``str()`` instead of raising ``AttributeError``.
        """
        name = getattr(annotation, "__name__", None)
        if isinstance(name, str):
            return name
        return str(annotation)

    def _process_function(self, function: Callable):
        """Populate name, description, signature, output and parameters."""
        signature = inspect.signature(function)
        self.name = function.__name__
        # getdoc returns None when there is no docstring. Use "" rather than
        # the literal string "None" that str(None) would produce.
        self.description = inspect.getdoc(function) or ""
        self.signature = str(signature)
        self.output = str(signature.return_annotation)
        self.parameters = [
            Parameter(
                name=param.name,
                type=self._annotation_name(param.annotation),
                default=param.default,
                required=param.default is inspect.Parameter.empty,
            )
            for param in signature.parameters.values()
        ]

    def to_openai(self):
        """Return the schema in OpenAI's ``{"type": "function", ...}`` format."""
        schema_dict = {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        param.name: {
                            "description": (
                                param.description
                                if isinstance(param.description, str)
                                else "None provided"
                            ),
                            "type": self._openai_type_mapping(param.type),
                        }
                        for param in self.parameters
                    },
                    "required": [
                        param.name for param in self.parameters if param.required
                    ],
                },
            },
        }
        return schema_dict

    def _openai_type_mapping(self, param_type: str) -> str:
        """Map a Python type name to an OpenAI schema type ("object" if unknown)."""
        return self._OPENAI_TYPE_MAP.get(param_type, "object")
|
101
|
+
|
102
|
+
|
103
|
+
def get_schema_pydantic(model: BaseModel) -> "FunctionSchema":
    """Build a ``FunctionSchema`` describing the fields of a pydantic model.

    Args:
        model: A pydantic model class or instance.

    Returns:
        A ``FunctionSchema`` with these fields populated:

        - ``signature``: lists the model's own annotated fields as
          ``name: type[ = default]``.
        - ``parameters``: holds one ``Parameter`` per field.
        - ``output``: left empty.
    """
    signature_parts = []
    parameters = []
    for field_name, field_model in model.__annotations__.items():
        field_info = model.__fields__[field_name]
        default_value = field_info.default
        # Not every annotation has ``__name__`` (e.g. ``int | None`` or string
        # annotations), so fall back to its string form instead of crashing.
        type_name = getattr(field_model, "__name__", None) or str(field_model)

        is_required = getattr(field_info, "is_required", None)
        if callable(is_required):
            # pydantic v2 ``FieldInfo``: required fields hold a sentinel as
            # ``default``, so ask the field directly.
            required = bool(is_required())
        else:
            # Older field objects: keep the original truthiness heuristic.
            required = not default_value

        if required:
            signature_parts.append(f"{field_name}: {type_name}")
        else:
            signature_parts.append(f"{field_name}: {type_name} = {default_value!r}")

        parameters.append(
            Parameter(
                name=field_name,
                description=getattr(field_info, "description", None),
                type=type_name,
                default=default_value,
                required=required,
            )
        )

    signature = f"({', '.join(signature_parts)}) -> str"
    # For a model class use its own name; ``model.__class__`` would be the
    # pydantic metaclass in that case.
    model_name = model.__name__ if isinstance(model, type) else type(model).__name__

    # FunctionSchema.__init__ only accepts a callable, so populate the
    # attributes directly instead of passing them as keyword arguments.
    schema = FunctionSchema.__new__(FunctionSchema)
    schema.function = model
    schema.name = model_name
    schema.description = model.__doc__ or ""
    schema.signature = signature
    schema.output = ""  # TODO: Implement output
    schema.parameters = parameters
    return schema
|
@@ -0,0 +1,29 @@
|
|
1
|
+
[tool.poetry]
|
2
|
+
name = "graphai-lib"
|
3
|
+
version = "0.0.1"
|
4
|
+
description = ""
|
5
|
+
authors = ["Aurelio AI <hello@aurelio.ai>"]
|
6
|
+
readme = "README.md"
|
7
|
+
packages = [{include = "graphai"}]
|
8
|
+
license = "MIT"
|
9
|
+
|
10
|
+
[tool.poetry.dependencies]
|
11
|
+
python = ">=3.10,<3.14"
|
12
|
+
semantic-router = ">=0.1.0.dev4"
|
13
|
+
networkx = "^3.4.2"
|
14
|
+
matplotlib = "^3.10.0"
|
15
|
+
|
16
|
+
[tool.poetry.group.dev.dependencies]
|
17
|
+
ipykernel = "^6.25.0"
|
18
|
+
ruff = "^0.1.5"
|
19
|
+
pytest = "^8.2"
|
20
|
+
pytest-mock = "^3.12.0"
|
21
|
+
pytest-cov = "^4.1.0"
|
22
|
+
pytest-xdist = "^3.5.0"
|
23
|
+
pytest-asyncio = "^0.24.0"
|
24
|
+
mypy = "^1.7.1"
|
25
|
+
black = {extras = ["jupyter"], version = ">=23.12.1,<24.5.0"}
|
26
|
+
|
27
|
+
[build-system]
|
28
|
+
requires = ["poetry-core"]
|
29
|
+
build-backend = "poetry.core.masonry.api"
|