graphai-lib 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ Metadata-Version: 2.1
2
+ Name: graphai-lib
3
+ Version: 0.0.1
4
+ Summary:
5
+ License: MIT
6
+ Author: Aurelio AI
7
+ Author-email: hello@aurelio.ai
8
+ Requires-Python: >=3.10,<3.14
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.10
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Requires-Dist: matplotlib (>=3.10.0,<4.0.0)
16
+ Requires-Dist: networkx (>=3.4.2,<4.0.0)
17
+ Requires-Dist: semantic-router (>=0.1.0.dev4)
18
+ Description-Content-Type: text/markdown
19
+
20
+ # Philosophy
21
+
22
+ 1. Async-first
23
+ 2. Minimize abstractions
24
+ 3. One way to do one thing
25
+ 4. Graph-based AI
@@ -0,0 +1,6 @@
1
+ # Philosophy
2
+
3
+ 1. Async-first
4
+ 2. Minimize abstractions
5
+ 3. One way to do one thing
6
+ 4. Graph-based AI
@@ -0,0 +1,4 @@
1
+ from graphai.graph import Graph
2
+ from graphai.nodes import node, router
3
+
4
+ __all__ = ["node", "router", "Graph"]
@@ -0,0 +1,63 @@
1
+ import asyncio
2
+ from typing import Optional
3
+ from collections.abc import AsyncIterator
4
+ from semantic_router.utils.logger import logger
5
+
6
+
7
# Module-level flag. It is not referenced anywhere else in this module.
# NOTE(review): confirm whether any other module reads it before relying on it.
log_stream = True
8
+
9
+ class Callback:
10
+ first_token = True
11
+ current_node_name: Optional[str] = None
12
+ active: bool = True
13
+ queue: asyncio.Queue
14
+
15
+ def __init__(self):
16
+ self.queue = asyncio.Queue()
17
+
18
+ def __call__(self, token: str, node_name: Optional[str] = None):
19
+ self._check_node_name(node_name=node_name)
20
+ # otherwise we just assume node is correct and send token
21
+ self.queue.put_nowait(token)
22
+
23
+ async def acall(self, token: str, node_name: Optional[str] = None):
24
+ self._check_node_name(node_name=node_name)
25
+ # otherwise we just assume node is correct and send token
26
+ self.queue.put_nowait(token)
27
+
28
+ async def aiter(self) -> AsyncIterator[str]:
29
+ """Used by receiver to get the tokens from the stream queue. Creates
30
+ a generator that yields tokens from the queue until the END token is
31
+ received.
32
+ """
33
+ while True:
34
+ token = await self.queue.get()
35
+ yield token
36
+ self.queue.task_done()
37
+ if token == "<graphai:END>":
38
+ break
39
+
40
+ async def start_node(self, node_name: str, active: bool = True):
41
+ self.current_node_name = node_name
42
+ if self.first_token:
43
+ # TODO JB: not sure if we need self.first_token
44
+ self.first_token = False
45
+ self.active = active
46
+ if self.active:
47
+ self.queue.put_nowait(f"<graphai:start:{node_name}>")
48
+
49
+ async def end_node(self, node_name: str):
50
+ self.current_node_name = None
51
+ if self.active:
52
+ self.queue.put_nowait(f"<graphai:end:{node_name}>")
53
+
54
+ async def close(self):
55
+ self.queue.put_nowait("<graphai:END>")
56
+
57
+ def _check_node_name(self, node_name: Optional[str] = None):
58
+ if node_name:
59
+ # we confirm this is the current node
60
+ if self.current_node_name != node_name:
61
+ raise ValueError(
62
+ f"Node name mismatch: {self.current_node_name} != {node_name}"
63
+ )
@@ -0,0 +1,198 @@
1
+ from typing import List, Dict, Any
2
+ from graphai.nodes.base import _Node
3
+ from graphai.callback import Callback
4
+ from semantic_router.utils.logger import logger
5
+
6
+
7
class Graph:
    """A directed graph of nodes that is executed step by step.

    Nodes are the classes produced by the ``node``/``router`` decorators.
    ``execute`` starts at the single start node and follows outgoing edges, or
    for router nodes the node named by the router's ``"choice"`` output. It
    stops once an end node has run or ``max_steps`` transitions were taken.

    Args:
        max_steps: Maximum number of node-to-node transitions ``execute`` may
            perform before raising.
    """

    def __init__(self, max_steps: int = 10):
        self.nodes = []
        self.edges = []
        self.start_node = None
        self.end_nodes = []
        # Callback class instantiated by ``get_callback``.
        self.Callback = Callback
        self.callback = None
        self.max_steps = max_steps

    def add_node(self, node):
        """Register ``node``, recording it as the start and/or an end node.

        Raises:
            Exception: If ``node`` is a start node and a start node already
                exists. The node is not added in that case.
        """
        # Validate before mutating so a rejected node is not left in ``nodes``.
        if node.is_start and self.start_node is not None:
            raise Exception(
                "Multiple start nodes are not allowed. Start node "
                f"'{self.start_node.name}' already exists, so new start "
                f"node '{node.name}' can not be added to the graph."
            )
        self.nodes.append(node)
        if node.is_start:
            self.start_node = node
        if node.is_end:
            self.end_nodes.append(node)

    def add_edge(self, source: _Node, destination: _Node):
        """Add a directed edge from ``source`` to ``destination``."""
        # TODO add logic to check that source and destination are nodes
        # and they exist in the graph object already
        edge = Edge(source, destination)
        self.edges.append(edge)

    def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
        """Connect every source to ``router`` and ``router`` to every destination.

        Raises:
            TypeError: If ``router`` is not a router node.
        """
        if not router.is_router:
            raise TypeError("A router object must be passed to the router parameter.")
        for source in sources:
            self.add_edge(source, router)
        for destination in destinations:
            self.add_edge(router, destination)

    def set_start_node(self, node: _Node):
        """Use ``node`` as the node ``execute`` starts from."""
        self.start_node = node

    def set_end_node(self, node: _Node):
        """Mark ``node`` as an end node of the graph.

        The node is registered in ``end_nodes``, which ``compile`` validates
        and ``execute`` honours. ``end_node`` is still set for callers that
        read it.
        """
        self.end_node = node
        if node not in self.end_nodes:
            self.end_nodes.append(node)

    def compile(self):
        """Check that the graph is runnable.

        Raises:
            Exception: If no start node or no end nodes are defined, or if
                ``_is_valid`` fails.
        """
        if not self.start_node:
            raise Exception("Start node not defined.")
        if not self.end_nodes:
            raise Exception("No end nodes defined.")
        if not self._is_valid():
            raise Exception("Graph is not valid.")

    def _is_valid(self):
        # Implement validation logic, e.g., checking for cycles, disconnected components, etc.
        return True

    def _validate_output(self, output: Dict[str, Any], node_name: str):
        """Raise ValueError unless ``output`` is a dict."""
        if not isinstance(output, dict):
            raise ValueError(
                f"Expected dictionary output from node {node_name}. "
                f"Instead, got {type(output)} from '{output}'."
            )

    async def execute(self, input):
        """Run the graph from the start node and return the final state.

        Args:
            input: Initial state dict passed to the start node.

        Returns:
            The accumulated state: ``input`` merged with every node's output.

        Raises:
            ValueError: If a node returns something other than a dict.
            Exception: If ``max_steps`` is reached, or if the next node cannot
                be resolved.
        """
        # TODO JB: may need to add init callback here to init the queue on every new execution
        if self.callback is None:
            self.callback = self.get_callback()
        current_node = self.start_node
        state = input
        steps = 0
        while True:
            if current_node.stream:
                # Streaming nodes get the callback and are bracketed by
                # start/end marker tokens.
                await self.callback.start_node(node_name=current_node.name)
                output = await current_node.invoke(input=state, callback=self.callback)
                self._validate_output(output=output, node_name=current_node.name)
                await self.callback.end_node(node_name=current_node.name)
            else:
                output = await current_node.invoke(input=state)
                self._validate_output(output=output, node_name=current_node.name)
            state = {**state, **output}
            # Nodes flagged ``end`` or registered via ``set_end_node`` stop the run.
            if current_node.is_end or current_node in self.end_nodes:
                break
            if current_node.is_router:
                # The router's "choice" names the next node. It has already
                # been merged into ``state`` above; popping only removes it
                # from the router's own output dict.
                next_node_name = str(output.pop("choice"))
                current_node = self._get_node_by_name(node_name=next_node_name)
            else:
                # Otherwise follow the (first) outgoing edge.
                current_node = self._get_next_node(current_node=current_node)
            steps += 1
            if steps >= self.max_steps:
                raise Exception(
                    f"Max steps reached: {self.max_steps}. You can modify this "
                    "by setting `max_steps` when initializing the Graph object."
                )
        # TODO JB: may need to add end callback here to close the queue for every execution
        # A streaming node's ``invoke`` inserts "callback" into the state dict
        # it receives; its presence means the stream must be closed.
        if self.callback and "callback" in state:
            await self.callback.close()
            del state["callback"]
        return state

    def get_callback(self):
        """Create a fresh ``self.Callback`` instance, store it and return it."""
        self.callback = self.Callback()
        return self.callback

    def _get_node_by_name(self, node_name: str) -> _Node:
        """Return the registered node called ``node_name``.

        Raises:
            Exception: If no registered node has that name.
        """
        for node in self.nodes:
            if node.name == node_name:
                return node
        raise Exception(f"Node with name {node_name} not found.")

    def _get_next_node(self, current_node):
        """Return the destination of the first edge leaving ``current_node``.

        Raises:
            Exception: If ``current_node`` has no outgoing edge.
        """
        for edge in self.edges:
            if edge.source == current_node:
                return edge.destination
        raise Exception(
            f"No outgoing edge found for current node '{current_node.name}'."
        )

    def visualize(self):
        """Draw the graph with networkx/matplotlib and show it.

        Acyclic graphs use a layered topological layout; graphs with cycles
        fall back to a spring layout.

        Raises:
            ImportError: If networkx or matplotlib is not installed.
        """
        try:
            import networkx as nx
        except ImportError as e:
            raise ImportError("NetworkX is required for visualization. Please install it with 'pip install networkx'.") from e

        try:
            import matplotlib.pyplot as plt
        except ImportError as e:
            raise ImportError("Matplotlib is required for visualization. Please install it with 'pip install matplotlib'.") from e

        G = nx.DiGraph()
        for node in self.nodes:
            G.add_node(node.name)
        for edge in self.edges:
            G.add_edge(edge.source.name, edge.destination.name)

        if nx.is_directed_acyclic_graph(G):
            logger.info("The graph is acyclic. Visualization will use a topological layout.")
            pos = self._topological_layout(list(nx.topological_generations(G)))
        else:
            logger.warning("The graph contains cycles. Visualization will use a spring layout.")
            pos = nx.spring_layout(G, k=1, iterations=50)

        plt.figure(figsize=(8, 6))
        nx.draw(G, pos, with_labels=True, node_color='lightblue',
                node_size=3000, font_size=8, font_weight='bold',
                arrows=True, edge_color='gray', arrowsize=20)

        plt.axis('off')
        plt.show()

    @staticmethod
    def _topological_layout(generations):
        """Compute node positions from topological generations.

        Each generation becomes one row, earlier generations higher up, and
        nodes are centred horizontally within their row. Coordinates are then
        scaled so the largest extent on either axis is 0.8.

        Args:
            generations: Sequence of node groups, e.g. the output of
                ``networkx.topological_generations``.

        Returns:
            Dict mapping each node to an ``(x, y)`` tuple.
        """
        n_generations = len(generations)
        pos = {}
        for i, generation in enumerate(generations):
            y = n_generations - i - 1
            # Mean of the in-row indices 0..len-1, used to centre the row.
            x_center = (len(generation) - 1) / 2
            for x, node in enumerate(generation):
                pos[node] = (x - x_center, y)

        max_x = max((abs(x) for x, _ in pos.values()), default=0)
        max_y = max((abs(y) for _, y in pos.values()), default=0)
        # A single column (linear chain) or single row has zero extent on
        # that axis; skip it instead of dividing by zero.
        limits = [0.8 / extent for extent in (max_x, max_y) if extent > 0]
        scale = min(limits, default=1.0)
        return {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
192
+
193
+
194
+
195
class Edge:
    """Directed connection from ``source`` to ``destination`` in a ``Graph``."""

    def __init__(self, source, destination):
        self.source = source
        self.destination = destination

    def __repr__(self) -> str:
        # Node classes carry a ``name`` attribute; fall back to the object itself.
        source = getattr(self.source, "name", self.source)
        destination = getattr(self.destination, "name", self.destination)
        return f"{type(self).__name__}(source={source!r}, destination={destination!r})"
@@ -0,0 +1,3 @@
1
+ from graphai.nodes.base import node, router
2
+
3
+ __all__ = ["node", "router"]
@@ -0,0 +1,148 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, Optional
3
+
4
+ from graphai.callback import Callback
5
+ from graphai.utils import FunctionSchema
6
+
7
+
8
class NodeMeta(type):
    """Metaclass that converts positional constructor args to keyword args.

    Instances are always constructed as ``cls(**kwargs)``.
    """

    @staticmethod
    def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """Map positional ``args`` onto the parameter names of ``cls_type.__init__``.

        Args:
            cls_type: The class being instantiated.
            args: Positional arguments passed to the constructor.

        Returns:
            Dict mapping parameter names (excluding ``self``) to the given
            positional values, in declaration order.

        Raises:
            TypeError: If more positional args are given than ``__init__``
                has positional parameters.
        """
        init_signature = inspect.signature(cls_type.__init__)
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        param_names = [
            name
            for name, param in init_signature.parameters.items()
            if name != "self" and param.kind in positional_kinds
        ]
        if len(args) > len(param_names):
            raise TypeError(
                f"{cls_type.__name__}() takes {len(param_names)} positional "
                f"argument(s) but {len(args)} were given"
            )
        return dict(zip(param_names, args))

    def __call__(cls, *args, **kwargs):
        named_positional_args = NodeMeta.positional_to_kwargs(cls, args)
        # Mirror Python's own error instead of silently overwriting a kwarg.
        duplicated = sorted(named_positional_args.keys() & kwargs.keys())
        if duplicated:
            raise TypeError(
                f"{cls.__name__}() got multiple values for argument(s): "
                f"{', '.join(duplicated)}"
            )
        kwargs.update(named_positional_args)
        return super().__call__(**kwargs)
19
+
20
+
21
class _Node:
    """Decorator factory that turns async functions into graph node classes.

    ``node = _Node()`` and ``router = _Node(is_router=True)`` are the public
    instances. Both work with or without parentheses, e.g. ``@node`` or
    ``@node(start=True, stream=True)``.
    """

    def __init__(
        self,
        is_router: bool = False,
    ):
        # Copied onto every node class this decorator produces.
        self.is_router = is_router

    def _node(
        self,
        func: Callable,
        start: bool = False,
        end: bool = False,
        stream: bool = False,
    ) -> Callable:
        """Validate ``func`` and wrap it in a node class.

        Args:
            func: The (async) function implementing the node.
            start: Whether this is the graph's start node.
            end: Whether this is an end node.
            stream: Whether the node receives a streaming ``callback``.

        Returns:
            A new class named after ``func``. It exposes ``invoke``,
            ``get_signature`` and the attributes ``name``, ``is_start``,
            ``is_end``, ``is_router``, ``stream`` and ``schema``.

        Raises:
            ValueError: If ``func`` is not callable.
        """
        if not callable(func):
            raise ValueError("Node must be a callable function.")

        func_signature = inspect.signature(func)
        schema = FunctionSchema(func)

        class NodeClass:
            _func_signature = func_signature
            # Placeholder; overwritten with the decorator's ``is_router`` below.
            is_router = None
            _stream = stream

            def __init__(self):
                # Names of every parameter ``func`` declares.
                self._expected_params = set(self._func_signature.parameters.keys())

            async def execute(self, *args, **kwargs):
                """Await ``func`` with only the arguments it declares."""
                params_dict = await self._parse_params(*args, **kwargs)
                return await func(**params_dict)

            async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
                """Select and validate the arguments to pass to ``func``.

                Unknown keyword arguments are dropped. Positional arguments
                are mapped onto ``func``'s parameters in order, and defaults
                are applied.

                Raises:
                    ValueError: If ``func`` takes ``callback`` but the node is
                        not streaming, or if a required parameter is missing.
                """
                expected_kwargs = {
                    k: v for k, v in kwargs.items() if k in self._expected_params
                }
                # ``func`` is a plain function, not a bound method, so there is
                # no ``self`` to skip: positional args fill its parameters from
                # the first one onwards.
                param_names = list(self._func_signature.parameters)
                expected_args_kwargs = dict(zip(param_names, args))
                # Keyword arguments win over positional ones.
                combined_params = {**expected_args_kwargs, **expected_kwargs}

                if "callback" in self._expected_params and not stream:
                    raise ValueError(
                        f"Node {func.__name__}: requires stream=True when callback is defined."
                    )
                bound_params = self._func_signature.bind_partial(**combined_params)
                bound_params.apply_defaults()
                params_dict = bound_params.arguments.copy()
                filtered_params = {
                    k: v for k, v in params_dict.items() if k in self._expected_params
                }
                # Report missing parameters in declaration order so the error
                # message is deterministic.
                missing_params = [p for p in param_names if p not in filtered_params]
                if missing_params:
                    raise ValueError(
                        f"Missing required parameters for the {func.__name__} node: {', '.join(missing_params)}"
                    )
                return filtered_params

            @classmethod
            def get_signature(cls):
                """Return the decorated function's parameters as an LLM-readable
                string, one ``name: annotation[ = default]`` per line.
                """
                signature_components = []
                if NodeClass._func_signature:
                    for param in NodeClass._func_signature.parameters.values():
                        if param.default is param.empty:
                            signature_components.append(f"{param.name}: {param.annotation}")
                        else:
                            signature_components.append(f"{param.name}: {param.annotation} = {param.default}")
                else:
                    return "No signature"
                return "\n".join(signature_components)

            @classmethod
            async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None):
                """Run the node against the state dict ``input``.

                NOTE: when a callback is given it is stored in ``input`` in
                place. ``Graph.execute`` looks for that "callback" key in its
                state to decide when to close the stream.

                Raises:
                    ValueError: If a callback is given to a non-streaming node.
                """
                if callback:
                    if stream:
                        input["callback"] = callback
                    else:
                        raise ValueError(
                            f"Error in node {func.__name__}. When callback provided, stream must be True."
                        )
                instance = cls()
                out = await instance.execute(**input)
                return out

        NodeClass.__name__ = func.__name__
        NodeClass.name = func.__name__
        NodeClass.__doc__ = func.__doc__
        NodeClass.is_start = start
        NodeClass.is_end = end
        NodeClass.is_router = self.is_router
        NodeClass.stream = stream
        NodeClass.schema = schema
        return NodeClass

    def __call__(
        self,
        func: Optional[Callable] = None,
        start: bool = False,
        end: bool = False,
        stream: bool = False,
    ):
        """Decorate ``func``, or return a decorator when called with options."""
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parenthesis
        def wrap(func: Callable, start=start, end=end, stream=stream) -> Callable:
            return self._node(func=func, start=start, end=end, stream=stream)
        if func:
            # Decorator is called without parenthesis
            return wrap(func=func, start=start, end=end, stream=stream)
        # Decorator is called with parenthesis
        return wrap
145
+
146
+
147
# Decorator for ordinary nodes, usable as ``@node`` or ``@node(start=..., end=..., stream=...)``.
node = _Node()
# Decorator for router nodes. ``Graph.execute`` uses a router's "choice"
# output to select the next node by name.
router = _Node(is_router=True)
@@ -0,0 +1,125 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Union, Optional
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
class Parameter(BaseModel):
    """Metadata for one parameter of a function exposed to an LLM.

    ``FunctionSchema`` builds these from ``inspect`` signature data.
    """

    class Config:
        # Permit field values of arbitrary (non-pydantic) Python types.
        arbitrary_types_allowed = True

    name: str = Field(description="The name of the parameter")
    description: Optional[str] = Field(
        default=None, description="The description of the parameter"
    )
    type: str = Field(description="The type of the parameter")
    default: Any = Field(description="The default value of the parameter")
    required: bool = Field(description="Whether the parameter is required")

    def to_openai(self):
        """Return ``{name: {"description": ..., "type": ...}}`` for this parameter.

        ``type`` is the raw type name stored on the model. It is not translated
        to a JSON-schema type here.
        """
        properties = dict(description=self.description, type=self.type)
        return {self.name: properties}
25
+
26
class FunctionSchema:
    """Describe a function in the formats LLMs expect for function calling.

    Built from a callable, it records the function's name, docstring,
    signature, return annotation and per-parameter metadata. ``to_openai``
    renders an OpenAI tool schema from them.
    """

    # NOTE(review): these are pydantic ``Field`` objects on a plain class, so
    # they only document the attributes; ``_process_function`` overwrites all
    # of them on the instance.
    name: str = Field(description="The name of the function")
    description: str = Field(description="The description of the function")
    signature: str = Field(description="The signature of the function")
    output: str = Field(description="The output of the function")
    parameters: List[Parameter] = Field(description="The parameters of the function")

    # Python type name -> OpenAI/JSON-schema type; anything else is "object".
    _OPENAI_TYPES = {
        "int": "number",
        "float": "number",
        "str": "string",
        "bool": "boolean",
    }

    def __init__(self, function: Union[Callable, BaseModel]):
        """Build the schema for ``function``.

        Raises:
            NotImplementedError: If ``function`` is a pydantic model instance.
            TypeError: If ``function`` is neither callable nor a BaseModel.
        """
        self.function = function
        if callable(function):
            self._process_function(function)
        elif isinstance(function, BaseModel):
            raise NotImplementedError("Pydantic BaseModel not implemented yet.")
        else:
            raise TypeError("Function must be a Callable or BaseModel")

    def _process_function(self, function: Callable):
        """Populate the schema attributes from ``function``'s signature."""
        signature = inspect.signature(function)
        self.name = function.__name__
        # ``inspect.getdoc`` returns None when there is no docstring; avoid
        # publishing the literal string "None" as the description.
        self.description = inspect.getdoc(function) or ""
        self.signature = str(signature)
        self.output = str(signature.return_annotation)
        self.parameters = [
            Parameter(
                name=param.name,
                type=self._annotation_name(param.annotation),
                default=param.default,
                required=param.default is inspect.Parameter.empty,
            )
            for param in signature.parameters.values()
        ]

    @staticmethod
    def _annotation_name(annotation: Any) -> str:
        """Return a readable name for a parameter annotation.

        String annotations (e.g. under ``from __future__ import annotations``)
        and some typing constructs have no ``__name__``, so fall back to
        ``str()``.
        """
        return getattr(annotation, "__name__", str(annotation))

    def to_openai(self):
        """Return the function as an OpenAI ``tools`` entry."""
        schema_dict = {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        param.name: {
                            "description": (
                                param.description
                                if isinstance(param.description, str)
                                else "None provided"
                            ),
                            "type": self._openai_type_mapping(param.type),
                        }
                        for param in self.parameters
                    },
                    "required": [
                        param.name for param in self.parameters if param.required
                    ],
                },
            },
        }
        return schema_dict

    def _openai_type_mapping(self, param_type: str) -> str:
        """Map a Python type name to an OpenAI schema type ("object" by default)."""
        return self._OPENAI_TYPES.get(param_type, "object")
101
+
102
+
103
def get_schema_pydantic(model: BaseModel) -> "FunctionSchema":
    """Build a ``FunctionSchema`` describing the fields of a pydantic model instance.

    The signature string lists each annotated model field as
    ``name: Type[ = default]``. Required fields get no default.

    Args:
        model: A pydantic model instance.

    Returns:
        A ``FunctionSchema`` whose ``parameters`` mirror the model's fields.
    """
    model_cls = type(model)
    # pydantic v2 exposes ``model_fields``; older versions only ``__fields__``.
    fields = getattr(model_cls, "model_fields", None)
    if fields is None:
        fields = model_cls.__fields__

    signature_parts = []
    parameters = []
    for field_name, field_model in model.__annotations__.items():
        field_info = fields.get(field_name)
        if field_info is None:
            # Annotated names that are not model fields (e.g. ClassVar).
            continue
        type_name = getattr(field_model, "__name__", str(field_model))
        default_value = field_info.default
        # Decide "has a default" from the field info rather than the default's
        # truthiness, so defaults like 0, False or "" are still shown.
        if hasattr(field_info, "is_required"):
            required = field_info.is_required()
        else:
            required = bool(getattr(field_info, "required", False))

        if required:
            signature_parts.append(f"{field_name}: {type_name}")
        else:
            factory = getattr(field_info, "default_factory", None)
            default_repr = "<factory>" if factory is not None else repr(default_value)
            signature_parts.append(f"{field_name}: {type_name} = {default_repr}")

        parameters.append(
            Parameter(
                name=field_name,
                description=getattr(field_info, "description", None),
                type=type_name,
                default=default_value,
                required=required,
            )
        )

    signature = f"({', '.join(signature_parts)}) -> str"
    # ``FunctionSchema.__init__`` only accepts a callable, so populate an
    # instance directly instead of passing keyword arguments to it.
    schema = FunctionSchema.__new__(FunctionSchema)
    schema.function = model
    schema.name = model_cls.__name__
    schema.description = model.__doc__ or ""
    schema.signature = signature
    schema.output = ""  # TODO: Implement output
    schema.parameters = parameters
    return schema
@@ -0,0 +1,29 @@
1
+ [tool.poetry]
2
+ name = "graphai-lib"
3
+ version = "0.0.1"
4
+ description = ""
5
+ authors = ["Aurelio AI <hello@aurelio.ai>"]
6
+ readme = "README.md"
7
+ packages = [{include = "graphai"}]
8
+ license = "MIT"
9
+
10
+ [tool.poetry.dependencies]
11
+ python = ">=3.10,<3.14"
12
+ semantic-router = ">=0.1.0.dev4"
13
+ networkx = "^3.4.2"
14
+ matplotlib = "^3.10.0"
15
+
16
+ [tool.poetry.group.dev.dependencies]
17
+ ipykernel = "^6.25.0"
18
+ ruff = "^0.1.5"
19
+ pytest = "^8.2"
20
+ pytest-mock = "^3.12.0"
21
+ pytest-cov = "^4.1.0"
22
+ pytest-xdist = "^3.5.0"
23
+ pytest-asyncio = "^0.24.0"
24
+ mypy = "^1.7.1"
25
+ black = {extras = ["jupyter"], version = ">=23.12.1,<24.5.0"}
26
+
27
+ [build-system]
28
+ requires = ["poetry-core"]
29
+ build-backend = "poetry.core.masonry.api"