graphai-lib 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graphai/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
1
  from graphai.graph import Graph
2
2
  from graphai.nodes import node, router
3
3
 
4
- __all__ = ["node", "router", "Graph"]
4
+ __all__ = ["node", "router", "Graph"]
graphai/callback.py CHANGED
@@ -1,12 +1,12 @@
1
1
  import asyncio
2
2
  from pydantic import Field
3
- from typing import Optional
3
+ from typing import Optional, Any
4
4
  from collections.abc import AsyncIterator
5
- from semantic_router.utils.logger import logger
6
5
 
7
6
 
8
7
  log_stream = True
9
8
 
9
+
10
10
  class Callback:
11
11
  identifier: str = Field(
12
12
  default="graphai",
@@ -14,7 +14,7 @@ class Callback:
14
14
  "The identifier for special tokens. This allows us to easily "
15
15
  "identify special tokens in the stream so we can handle them "
16
16
  "correctly in any downstream process."
17
- )
17
+ ),
18
18
  )
19
19
  special_token_format: str = Field(
20
20
  default="<{identifier}:{token}:{params}>",
@@ -32,8 +32,8 @@ class Callback:
32
32
  examples=[
33
33
  "<{identifier}:{token}:{params}>",
34
34
  "<[{identifier} | {token} | {params}]>",
35
- "<{token}:{params}>"
36
- ]
35
+ "<{token}:{params}>",
36
+ ],
37
37
  )
38
38
  token_format: str = Field(
39
39
  default="{token}",
@@ -41,27 +41,23 @@ class Callback:
41
41
  "The format for streamed tokens. This is used to format the "
42
42
  "tokens typically returned from LLMs. By default, no special "
43
43
  "formatting is applied."
44
- )
44
+ ),
45
45
  )
46
46
  _first_token: bool = Field(
47
47
  default=True,
48
48
  description="Whether this is the first token in the stream.",
49
- exclude=True
49
+ exclude=True,
50
50
  )
51
51
  _current_node_name: Optional[str] = Field(
52
- default=None,
53
- description="The name of the current node.",
54
- exclude=True
52
+ default=None, description="The name of the current node.", exclude=True
55
53
  )
56
54
  _active: bool = Field(
57
- default=True,
58
- description="Whether the callback is active.",
59
- exclude=True
55
+ default=True, description="Whether the callback is active.", exclude=True
60
56
  )
61
57
  _done: bool = Field(
62
58
  default=False,
63
59
  description="Whether the stream is done and should be closed.",
64
- exclude=True
60
+ exclude=True,
65
61
  )
66
62
  queue: asyncio.Queue
67
63
 
@@ -83,7 +79,7 @@ class Callback:
83
79
  @property
84
80
  def first_token(self) -> bool:
85
81
  return self._first_token
86
-
82
+
87
83
  @first_token.setter
88
84
  def first_token(self, value: bool):
89
85
  self._first_token = value
@@ -91,7 +87,7 @@ class Callback:
91
87
  @property
92
88
  def current_node_name(self) -> Optional[str]:
93
89
  return self._current_node_name
94
-
90
+
95
91
  @current_node_name.setter
96
92
  def current_node_name(self, value: Optional[str]):
97
93
  self._current_node_name = value
@@ -99,7 +95,7 @@ class Callback:
99
95
  @property
100
96
  def active(self) -> bool:
101
97
  return self._active
102
-
98
+
103
99
  @active.setter
104
100
  def active(self, value: bool):
105
101
  self._active = value
@@ -110,7 +106,7 @@ class Callback:
110
106
  self._check_node_name(node_name=node_name)
111
107
  # otherwise we just assume node is correct and send token
112
108
  self.queue.put_nowait(token)
113
-
109
+
114
110
  async def acall(self, token: str, node_name: Optional[str] = None):
115
111
  # TODO JB: do we need to have `node_name` param?
116
112
  if self._done:
@@ -118,16 +114,13 @@ class Callback:
118
114
  self._check_node_name(node_name=node_name)
119
115
  # otherwise we just assume node is correct and send token
120
116
  self.queue.put_nowait(token)
121
-
117
+
122
118
  async def aiter(self) -> AsyncIterator[str]:
123
119
  """Used by receiver to get the tokens from the stream queue. Creates
124
120
  a generator that yields tokens from the queue until the END token is
125
121
  received.
126
122
  """
127
- end_token = await self._build_special_token(
128
- name="END",
129
- params=None
130
- )
123
+ end_token = await self._build_special_token(name="END", params=None)
131
124
  while True: # Keep going until we see the END token
132
125
  try:
133
126
  if self._done and self.queue.empty():
@@ -142,8 +135,7 @@ class Callback:
142
135
  self._done = True # Mark as done after processing all tokens
143
136
 
144
137
  async def start_node(self, node_name: str, active: bool = True):
145
- """Starts a new node and emits the start token.
146
- """
138
+ """Starts a new node and emits the start token."""
147
139
  if self._done:
148
140
  raise RuntimeError("Cannot start node on a closed stream")
149
141
  self.current_node_name = node_name
@@ -152,27 +144,23 @@ class Callback:
152
144
  self.active = active
153
145
  if self.active:
154
146
  token = await self._build_special_token(
155
- name=f"{self.current_node_name}:start",
156
- params=None
147
+ name=f"{self.current_node_name}:start", params=None
157
148
  )
158
149
  self.queue.put_nowait(token)
159
150
  # TODO JB: should we use two tokens here?
160
151
  node_token = await self._build_special_token(
161
- name=self.current_node_name,
162
- params=None
152
+ name=self.current_node_name, params=None
163
153
  )
164
154
  self.queue.put_nowait(node_token)
165
-
155
+
166
156
  async def end_node(self, node_name: str):
167
- """Emits the end token for the current node.
168
- """
157
+ """Emits the end token for the current node."""
169
158
  if self._done:
170
159
  raise RuntimeError("Cannot end node on a closed stream")
171
- #self.current_node_name = node_name
160
+ # self.current_node_name = node_name
172
161
  if self.active:
173
162
  node_token = await self._build_special_token(
174
- name=f"{self.current_node_name}:end",
175
- params=None
163
+ name=f"{self.current_node_name}:end", params=None
176
164
  )
177
165
  self.queue.put_nowait(node_token)
178
166
 
@@ -182,10 +170,7 @@ class Callback:
182
170
  """
183
171
  if self._done:
184
172
  return
185
- end_token = await self._build_special_token(
186
- name="END",
187
- params=None
188
- )
173
+ end_token = await self._build_special_token(name="END", params=None)
189
174
  self._done = True # Set done before putting the end token
190
175
  self.queue.put_nowait(end_token)
191
176
  # Don't wait for queue.join() as it can cause deadlock
@@ -198,8 +183,10 @@ class Callback:
198
183
  raise ValueError(
199
184
  f"Node name mismatch: {self.current_node_name} != {node_name}"
200
185
  )
201
-
202
- async def _build_special_token(self, name: str, params: dict[str, any] | None = None):
186
+
187
+ async def _build_special_token(
188
+ self, name: str, params: dict[str, Any] | None = None
189
+ ):
203
190
  if params:
204
191
  params_str = ",".join([f"{k}={v}" for k, v in params.items()])
205
192
  else:
@@ -209,7 +196,5 @@ class Callback:
209
196
  else:
210
197
  identifier = ""
211
198
  return self.special_token_format.format(
212
- identifier=identifier,
213
- token=name,
214
- params=params_str
199
+ identifier=identifier, token=name, params=params_str
215
200
  )
graphai/graph.py CHANGED
@@ -1,17 +1,33 @@
1
- from typing import List, Dict, Any, Optional
2
- from graphai.nodes.base import _Node
1
+ from typing import List, Dict, Any, Optional, Protocol, Type
3
2
  from graphai.callback import Callback
4
- from semantic_router.utils.logger import logger
3
+ from graphai.utils import logger
4
+
5
+
6
+ class NodeProtocol(Protocol):
7
+ """Protocol defining the interface of a decorated node."""
8
+ name: str
9
+ is_start: bool
10
+ is_end: bool
11
+ is_router: bool
12
+ stream: bool
13
+
14
+ async def invoke(
15
+ self,
16
+ input: Dict[str, Any],
17
+ callback: Optional[Callback] = None,
18
+ state: Optional[Dict[str, Any]] = None
19
+ ) -> Dict[str, Any]: ...
5
20
 
6
21
 
7
22
  class Graph:
8
- def __init__(self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None):
9
- self.nodes: Dict[str, _Node] = {}
23
+ def __init__(
24
+ self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
25
+ ):
26
+ self.nodes: Dict[str, NodeProtocol] = {}
10
27
  self.edges: List[Any] = []
11
- self.start_node: Optional[_Node] = None
12
- self.end_nodes: List[_Node] = []
13
- self.Callback = Callback
14
- self.callback = None
28
+ self.start_node: Optional[NodeProtocol] = None
29
+ self.end_nodes: List[NodeProtocol] = []
30
+ self.Callback: Type[Callback] = Callback
15
31
  self.max_steps = max_steps
16
32
  self.state = initial_state or {}
17
33
 
@@ -32,7 +48,7 @@ class Graph:
32
48
  """Reset the graph state to an empty dict."""
33
49
  self.state = {}
34
50
 
35
- def add_node(self, node):
51
+ def add_node(self, node: NodeProtocol):
36
52
  if node.name in self.nodes:
37
53
  raise Exception(f"Node with name '{node.name}' already exists.")
38
54
  self.nodes[node.name] = node
@@ -47,9 +63,9 @@ class Graph:
47
63
  if node.is_end:
48
64
  self.end_nodes.append(node)
49
65
 
50
- def add_edge(self, source: _Node | str, destination: _Node | str):
66
+ def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str):
51
67
  """Adds an edge between two nodes that already exist in the graph.
52
-
68
+
53
69
  Args:
54
70
  source: The source node or its name.
55
71
  destination: The destination node or its name.
@@ -60,7 +76,7 @@ class Graph:
60
76
  source_node = self.nodes.get(source)
61
77
  else:
62
78
  # Check if it's a node-like object by looking for required attributes
63
- if hasattr(source, 'name'):
79
+ if hasattr(source, "name"):
64
80
  source_node = self.nodes.get(source.name)
65
81
  if source_node is None:
66
82
  raise ValueError(
@@ -71,7 +87,7 @@ class Graph:
71
87
  destination_node = self.nodes.get(destination)
72
88
  else:
73
89
  # Check if it's a node-like object by looking for required attributes
74
- if hasattr(destination, 'name'):
90
+ if hasattr(destination, "name"):
75
91
  destination_node = self.nodes.get(destination.name)
76
92
  if destination_node is None:
77
93
  raise ValueError(
@@ -80,17 +96,19 @@ class Graph:
80
96
  edge = Edge(source_node, destination_node)
81
97
  self.edges.append(edge)
82
98
 
83
- def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
99
+ def add_router(
100
+ self, sources: list[NodeProtocol], router: NodeProtocol, destinations: List[NodeProtocol]
101
+ ):
84
102
  if not router.is_router:
85
103
  raise TypeError("A router object must be passed to the router parameter.")
86
104
  [self.add_edge(source, router) for source in sources]
87
105
  for destination in destinations:
88
106
  self.add_edge(router, destination)
89
107
 
90
- def set_start_node(self, node: _Node):
108
+ def set_start_node(self, node: NodeProtocol):
91
109
  self.start_node = node
92
110
 
93
- def set_end_node(self, node: _Node):
111
+ def set_end_node(self, node: NodeProtocol):
94
112
  self.end_node = node
95
113
 
96
114
  def compile(self):
@@ -112,11 +130,15 @@ class Graph:
112
130
  f"Instead, got {type(output)} from '{output}'."
113
131
  )
114
132
 
115
- async def execute(self, input):
133
+ async def execute(self, input, callback: Callback | None = None):
116
134
  # TODO JB: may need to add init callback here to init the queue on every new execution
117
- if self.callback is None:
118
- self.callback = self.get_callback()
135
+ if callback is None:
136
+ callback = self.get_callback()
137
+
138
+ # Type assertion to tell the type checker that start_node is not None after compile()
139
+ assert self.start_node is not None, "Graph must be compiled before execution"
119
140
  current_node = self.start_node
141
+
120
142
  state = input
121
143
  # Don't reset the graph state if it was initialized with initial_state
122
144
  steps = 0
@@ -124,11 +146,13 @@ class Graph:
124
146
  # we invoke the node here
125
147
  if current_node.stream:
126
148
  # add callback tokens and param here if we are streaming
127
- await self.callback.start_node(node_name=current_node.name)
149
+ await callback.start_node(node_name=current_node.name)
128
150
  # Include graph's internal state in the node execution context
129
- output = await current_node.invoke(input=state, callback=self.callback, state=self.state)
151
+ output = await current_node.invoke(
152
+ input=state, callback=callback, state=self.state
153
+ )
130
154
  self._validate_output(output=output, node_name=current_node.name)
131
- await self.callback.end_node(node_name=current_node.name)
155
+ await callback.end_node(node_name=current_node.name)
132
156
  else:
133
157
  # Include graph's internal state in the node execution context
134
158
  output = await current_node.invoke(input=state, state=self.state)
@@ -153,24 +177,38 @@ class Graph:
153
177
  "by setting `max_steps` when initializing the Graph object."
154
178
  )
155
179
  # TODO JB: may need to add end callback here to close the queue for every execution
156
- if self.callback and "callback" in state:
157
- await self.callback.close()
180
+ if callback and "callback" in state:
181
+ await callback.close()
158
182
  del state["callback"]
159
183
  return state
160
184
 
161
185
  def get_callback(self):
162
- self.callback = self.Callback()
163
- return self.callback
186
+ """Get a new instance of the callback class.
187
+
188
+ :return: A new instance of the callback class.
189
+ :rtype: Callback
190
+ """
191
+ callback = self.Callback()
192
+ return callback
193
+
194
+ def set_callback(self, callback_class: Type[Callback]):
195
+ """Set the callback class that is returned by the `get_callback` method and used
196
+ as the default callback when no callback is passed to the `execute` method.
197
+
198
+ :param callback_class: The callback class to use as the default callback.
199
+ :type callback_class: Type[Callback]
200
+ """
201
+ self.Callback = callback_class
164
202
 
165
- def _get_node_by_name(self, node_name: str) -> _Node:
203
+ def _get_node_by_name(self, node_name: str) -> NodeProtocol:
166
204
  """Get a node by its name.
167
-
205
+
168
206
  Args:
169
207
  node_name: The name of the node to find.
170
-
208
+
171
209
  Returns:
172
210
  The node with the given name.
173
-
211
+
174
212
  Raises:
175
213
  Exception: If no node with the given name is found.
176
214
  """
@@ -191,12 +229,16 @@ class Graph:
191
229
  try:
192
230
  import networkx as nx
193
231
  except ImportError:
194
- raise ImportError("NetworkX is required for visualization. Please install it with 'pip install networkx'.")
232
+ raise ImportError(
233
+ "NetworkX is required for visualization. Please install it with 'pip install networkx'."
234
+ )
195
235
 
196
236
  try:
197
237
  import matplotlib.pyplot as plt
198
238
  except ImportError:
199
- raise ImportError("Matplotlib is required for visualization. Please install it with 'pip install matplotlib'.")
239
+ raise ImportError(
240
+ "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
241
+ )
200
242
 
201
243
  G = nx.DiGraph()
202
244
 
@@ -207,7 +249,9 @@ class Graph:
207
249
  G.add_edge(edge.source.name, edge.destination.name)
208
250
 
209
251
  if nx.is_directed_acyclic_graph(G):
210
- logger.info("The graph is acyclic. Visualization will use a topological layout.")
252
+ logger.info(
253
+ "The graph is acyclic. Visualization will use a topological layout."
254
+ )
211
255
  # Use topological layout if acyclic
212
256
  # Compute the topological generations
213
257
  generations = list(nx.topological_generations(G))
@@ -241,20 +285,30 @@ class Graph:
241
285
  pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
242
286
 
243
287
  else:
244
- print("Warning: The graph contains cycles. Visualization will use a spring layout.")
288
+ print(
289
+ "Warning: The graph contains cycles. Visualization will use a spring layout."
290
+ )
245
291
  pos = nx.spring_layout(G, k=1, iterations=50)
246
292
 
247
293
  plt.figure(figsize=(8, 6))
248
- nx.draw(G, pos, with_labels=True, node_color='lightblue',
249
- node_size=3000, font_size=8, font_weight='bold',
250
- arrows=True, edge_color='gray', arrowsize=20)
294
+ nx.draw(
295
+ G,
296
+ pos,
297
+ with_labels=True,
298
+ node_color="lightblue",
299
+ node_size=3000,
300
+ font_size=8,
301
+ font_weight="bold",
302
+ arrows=True,
303
+ edge_color="gray",
304
+ arrowsize=20,
305
+ )
251
306
 
252
- plt.axis('off')
307
+ plt.axis("off")
253
308
  plt.show()
254
309
 
255
310
 
256
-
257
311
  class Edge:
258
312
  def __init__(self, source, destination):
259
313
  self.source = source
260
- self.destination = destination
314
+ self.destination = destination
graphai/nodes/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from graphai.nodes.base import node, router
2
2
 
3
- __all__ = ["node", "router"]
3
+ __all__ = ["node", "router"]
graphai/nodes/base.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import inspect
2
2
  from typing import Any, Callable, Dict, Optional
3
+ from pydantic import Field
3
4
 
4
5
  from graphai.callback import Callback
5
6
  from graphai.utils import FunctionSchema
@@ -9,7 +10,11 @@ class NodeMeta(type):
9
10
  @staticmethod
10
11
  def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
11
12
  init_signature = inspect.signature(cls_type.__init__)
12
- init_params = {name: arg for name, arg in init_signature.parameters.items() if name != "self"}
13
+ init_params = {
14
+ name: arg
15
+ for name, arg in init_signature.parameters.items()
16
+ if name != "self"
17
+ }
13
18
  return init_params
14
19
 
15
20
  def __call__(cls, *args, **kwargs):
@@ -33,18 +38,32 @@ class _Node:
33
38
  stream: bool = False,
34
39
  name: str | None = None,
35
40
  ) -> Callable:
36
- """Decorator validating node structure.
37
- """
41
+ """Decorator validating node structure."""
38
42
  if not callable(func):
39
43
  raise ValueError("Node must be a callable function.")
40
-
44
+
41
45
  func_signature = inspect.signature(func)
42
- schema = FunctionSchema(func)
46
+ schema: FunctionSchema = FunctionSchema.from_callable(func)
43
47
 
44
48
  class NodeClass:
45
49
  _func_signature = func_signature
46
- is_router = None
47
- _stream = stream
50
+ is_router: bool = Field(
51
+ default=False, description="Whether the node is a router."
52
+ )
53
+ # following attributes will be overridden by the decorator
54
+ name: str | None = Field(default=None, description="The name of the node.")
55
+ is_start: bool = Field(
56
+ default=False, description="Whether the node is the start of the graph."
57
+ )
58
+ is_end: bool = Field(
59
+ default=False, description="Whether the node is the end of the graph."
60
+ )
61
+ schema: FunctionSchema | None = Field(
62
+ default=None, description="The schema of the node."
63
+ )
64
+ stream: bool = Field(
65
+ default=False, description="Whether the node includes streaming object."
66
+ )
48
67
 
49
68
  def __init__(self):
50
69
  self._expected_params = set(self._func_signature.parameters.keys())
@@ -56,9 +75,13 @@ class _Node:
56
75
 
57
76
  async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
58
77
  # filter out unexpected keyword args
59
- expected_kwargs = {k: v for k, v in kwargs.items() if k in self._expected_params}
78
+ expected_kwargs = {
79
+ k: v for k, v in kwargs.items() if k in self._expected_params
80
+ }
60
81
  # Convert args to kwargs based on the function signature
61
- args_names = list(self._func_signature.parameters.keys())[1:len(args)+1] # skip 'self'
82
+ args_names = list(self._func_signature.parameters.keys())[
83
+ 1 : len(args) + 1
84
+ ] # skip 'self'
62
85
  expected_args_kwargs = dict(zip(args_names, args))
63
86
  # Combine filtered args and kwargs
64
87
  combined_params = {**expected_args_kwargs, **expected_kwargs}
@@ -87,7 +110,6 @@ class _Node:
87
110
  )
88
111
  return filtered_params
89
112
 
90
-
91
113
  @classmethod
92
114
  def get_signature(cls):
93
115
  """Returns the signature of the decorated function as LLM readable
@@ -97,15 +119,24 @@ class _Node:
97
119
  if NodeClass._func_signature:
98
120
  for param in NodeClass._func_signature.parameters.values():
99
121
  if param.default is param.empty:
100
- signature_components.append(f"{param.name}: {param.annotation}")
122
+ signature_components.append(
123
+ f"{param.name}: {param.annotation}"
124
+ )
101
125
  else:
102
- signature_components.append(f"{param.name}: {param.annotation} = {param.default}")
126
+ signature_components.append(
127
+ f"{param.name}: {param.annotation} = {param.default}"
128
+ )
103
129
  else:
104
130
  return "No signature"
105
131
  return "\n".join(signature_components)
106
132
 
107
133
  @classmethod
108
- async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None, state: Optional[Dict[str, Any]] = None):
134
+ async def invoke(
135
+ cls,
136
+ input: Dict[str, Any],
137
+ callback: Optional[Callback] = None,
138
+ state: Optional[Dict[str, Any]] = None,
139
+ ):
109
140
  if callback:
110
141
  if stream:
111
142
  input["callback"] = callback
@@ -116,13 +147,16 @@ class _Node:
116
147
  # Add state to the input if present and the parameter exists in the function signature
117
148
  if state is not None and "state" in cls._func_signature.parameters:
118
149
  input["state"] = state
119
-
150
+
120
151
  instance = cls()
121
152
  out = await instance.execute(**input)
122
153
  return out
123
154
 
124
155
  NodeClass.__name__ = func.__name__
125
- NodeClass.name = name or func.__name__
156
+ node_class_name = name or func.__name__
157
+ if node_class_name is None:
158
+ raise ValueError("Unexpected error: node name not set.")
159
+ NodeClass.name = node_class_name
126
160
  NodeClass.__doc__ = func.__doc__
127
161
  NodeClass.is_start = start
128
162
  NodeClass.is_end = end
@@ -141,8 +175,11 @@ class _Node:
141
175
  ):
142
176
  # We must wrap the call to the decorator in a function for it to work
143
177
  # correctly with or without parenthesis
144
- def wrap(func: Callable, start=start, end=end, stream=stream, name=name) -> Callable:
178
+ def wrap(
179
+ func: Callable, start=start, end=end, stream=stream, name=name
180
+ ) -> Callable:
145
181
  return self._node(func=func, start=start, end=end, stream=stream, name=name)
182
+
146
183
  if func:
147
184
  # Decorator is called without parenthesis
148
185
  return wrap(func=func, start=start, end=end, stream=stream, name=name)
graphai/utils.py CHANGED
@@ -1,11 +1,81 @@
1
1
  import inspect
2
- from typing import Any, Callable, Dict, List, Union, Optional
2
+ from typing import Any, Callable, List, Optional
3
3
  from pydantic import BaseModel, Field
4
+ import logging
5
+
6
+ import colorlog
7
+
8
+
9
+ class CustomFormatter(colorlog.ColoredFormatter):
10
+ """Custom formatter for the logger."""
11
+
12
+ def __init__(self):
13
+ super().__init__(
14
+ "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
15
+ datefmt="%Y-%m-%d %H:%M:%S",
16
+ log_colors={
17
+ "DEBUG": "cyan",
18
+ "INFO": "green",
19
+ "WARNING": "yellow",
20
+ "ERROR": "red",
21
+ "CRITICAL": "bold_red",
22
+ },
23
+ reset=True,
24
+ style="%",
25
+ )
26
+
27
+
28
+ def add_coloured_handler(logger):
29
+ """Add a coloured handler to the logger."""
30
+ formatter = CustomFormatter()
31
+ console_handler = logging.StreamHandler()
32
+ console_handler.setFormatter(formatter)
33
+ logger.addHandler(console_handler)
34
+ return logger
35
+
36
+
37
+ def setup_custom_logger(name):
38
+ """Setup a custom logger."""
39
+ logger = logging.getLogger(name)
40
+
41
+ if not logger.hasHandlers():
42
+ add_coloured_handler(logger)
43
+ logger.setLevel(logging.INFO)
44
+ logger.propagate = False
45
+
46
+ return logger
47
+
48
+
49
+ logger: logging.Logger = setup_custom_logger(__name__)
50
+
51
+
52
+ def openai_type_mapping(param_type: str) -> str:
53
+ if param_type == "int":
54
+ return "number"
55
+ elif param_type == "float":
56
+ return "number"
57
+ elif param_type == "str":
58
+ return "string"
59
+ elif param_type == "bool":
60
+ return "boolean"
61
+ else:
62
+ return "object"
4
63
 
5
64
 
6
65
  class Parameter(BaseModel):
7
- class Config:
8
- arbitrary_types_allowed = True
66
+ """Parameter for a function.
67
+
68
+ :param name: The name of the parameter.
69
+ :type name: str
70
+ :param description: The description of the parameter.
71
+ :type description: Optional[str]
72
+ :param type: The type of the parameter.
73
+ :type type: str
74
+ :param default: The default value of the parameter.
75
+ :type default: Any
76
+ :param required: Whether the parameter is required.
77
+ :type required: bool
78
+ """
9
79
 
10
80
  name: str = Field(description="The name of the parameter")
11
81
  description: Optional[str] = Field(
@@ -15,15 +85,22 @@ class Parameter(BaseModel):
15
85
  default: Any = Field(description="The default value of the parameter")
16
86
  required: bool = Field(description="Whether the parameter is required")
17
87
 
18
- def to_openai(self):
88
+ def to_dict(self) -> dict[str, Any]:
89
+ """Convert the parameter to a dictionary for an standard dictionary-based function schema.
90
+ This is the most common format used by LLM providers, including OpenAI, Ollama, and others.
91
+
92
+ :return: The parameter in dictionary format.
93
+ :rtype: dict[str, Any]
94
+ """
19
95
  return {
20
96
  self.name: {
21
97
  "description": self.description,
22
- "type": self.type,
98
+ "type": openai_type_mapping(self.type),
23
99
  }
24
100
  }
25
101
 
26
- class FunctionSchema:
102
+
103
+ class FunctionSchema(BaseModel):
27
104
  """Class that consumes a function and can return a schema required by
28
105
  different LLMs for function calling.
29
106
  """
@@ -32,35 +109,68 @@ class FunctionSchema:
32
109
  description: str = Field(description="The description of the function")
33
110
  signature: str = Field(description="The signature of the function")
34
111
  output: str = Field(description="The output of the function")
35
- parameters: List[Parameter] = Field(description="The parameters of the function")
112
+ parameters: list[Parameter] = Field(description="The parameters of the function")
113
+
114
+ @classmethod
115
+ def from_callable(cls, function: Callable) -> "FunctionSchema":
116
+ """Initialize the FunctionSchema.
36
117
 
37
- def __init__(self, function: Union[Callable, BaseModel]):
38
- self.function = function
118
+ :param function: The function to consume.
119
+ :type function: Callable
120
+ """
39
121
  if callable(function):
40
- self._process_function(function)
122
+ name = function.__name__
123
+ description = str(inspect.getdoc(function))
124
+ if description is None or description == "":
125
+ logger.warning(f"Function {name} has no docstring")
126
+ signature = str(inspect.signature(function))
127
+ output = str(inspect.signature(function).return_annotation)
128
+ parameters = []
129
+ for param in inspect.signature(function).parameters.values():
130
+ parameters.append(
131
+ Parameter(
132
+ name=param.name,
133
+ type=param.annotation.__name__,
134
+ default=param.default,
135
+ required=param.default is inspect.Parameter.empty,
136
+ )
137
+ )
138
+ return cls.model_construct(
139
+ name=name,
140
+ description=description,
141
+ signature=signature,
142
+ output=output,
143
+ parameters=parameters,
144
+ )
41
145
  elif isinstance(function, BaseModel):
42
146
  raise NotImplementedError("Pydantic BaseModel not implemented yet.")
43
147
  else:
44
148
  raise TypeError("Function must be a Callable or BaseModel")
45
149
 
46
- def _process_function(self, function: Callable):
47
- self.name = function.__name__
48
- self.description = str(inspect.getdoc(function))
49
- self.signature = str(inspect.signature(function))
50
- self.output = str(inspect.signature(function).return_annotation)
51
- parameters = []
52
- for param in inspect.signature(function).parameters.values():
53
- parameters.append(
54
- Parameter(
55
- name=param.name,
56
- type=param.annotation.__name__,
57
- default=param.default,
58
- required=param.default is inspect.Parameter.empty,
150
+ @classmethod
151
+ def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
152
+ signature_parts = []
153
+ for field_name, field_model in model.__annotations__.items():
154
+ field_info = model.model_fields[field_name]
155
+ default_value = field_info.default
156
+ if default_value:
157
+ default_repr = repr(default_value)
158
+ signature_part = (
159
+ f"{field_name}: {field_model.__name__} = {default_repr}"
59
160
  )
60
- )
61
- self.parameters = parameters
161
+ else:
162
+ signature_part = f"{field_name}: {field_model.__name__}"
163
+ signature_parts.append(signature_part)
164
+ signature = f"({', '.join(signature_parts)}) -> str"
165
+ return cls.model_construct(
166
+ name=model.__class__.__name__,
167
+ description=model.__doc__ or "",
168
+ signature=signature,
169
+ output="", # TODO: Implement output
170
+ parameters=[],
171
+ )
62
172
 
63
- def to_openai(self):
173
+ def to_dict(self) -> dict:
64
174
  schema_dict = {
65
175
  "type": "function",
66
176
  "function": {
@@ -69,15 +179,7 @@ class FunctionSchema:
69
179
  "parameters": {
70
180
  "type": "object",
71
181
  "properties": {
72
- param.name: {
73
- "description": (
74
- param.description
75
- if isinstance(param.description, str)
76
- else "None provided"
77
- ),
78
- "type": self._openai_type_mapping(param.type),
79
- }
80
- for param in self.parameters
182
+ k: v for param in self.parameters for k, v in param.to_dict().items()
81
183
  },
82
184
  "required": [
83
185
  param.name for param in self.parameters if param.required
@@ -87,39 +189,17 @@ class FunctionSchema:
87
189
  }
88
190
  return schema_dict
89
191
 
90
- def _openai_type_mapping(self, param_type: str) -> str:
91
- if param_type == "int":
92
- return "number"
93
- elif param_type == "float":
94
- return "number"
95
- elif param_type == "str":
96
- return "string"
97
- elif param_type == "bool":
98
- return "boolean"
99
- else:
100
- return "object"
192
+ def to_openai(self) -> dict:
193
+ return self.to_dict()
101
194
 
102
195
 
103
- def get_schema_pydantic(model: BaseModel) -> Dict[str, Any]:
104
- signature_parts = []
105
- for field_name, field_model in model.__annotations__.items():
106
- field_info = model.__fields__[field_name]
107
- default_value = field_info.default
196
+ DEFAULT = set(["default", "openai", "ollama", "litellm"])
108
197
 
109
- if default_value:
110
- default_repr = repr(default_value)
111
- signature_part = (
112
- f"{field_name}: {field_model.__name__} = {default_repr}"
113
- )
114
- else:
115
- signature_part = f"{field_name}: {field_model.__name__}"
116
-
117
- signature_parts.append(signature_part)
118
- signature = f"({', '.join(signature_parts)}) -> str"
119
- schema = FunctionSchema(
120
- name=model.__class__.__name__,
121
- description=model.__doc__,
122
- signature=signature,
123
- output="", # TODO: Implement output
124
- )
125
- return schema
198
+
199
+ def get_schemas(callables: List[Callable], format: str = "default") -> list[dict]:
200
+ if format in DEFAULT:
201
+ return [
202
+ FunctionSchema.from_callable(callable).to_dict() for callable in callables
203
+ ]
204
+ else:
205
+ raise ValueError(f"Format {format} not supported")
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.4
3
+ Version: 0.0.6
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
7
- Requires-Dist: semantic-router>=0.1.5
8
7
  Requires-Dist: networkx>=3.4.2
9
8
  Requires-Dist: matplotlib>=3.10.0
9
+ Requires-Dist: pydantic>=2.11.1
10
+ Requires-Dist: colorlog>=6.9.0
10
11
  Provides-Extra: dev
11
12
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
12
13
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -16,7 +17,7 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
16
17
  Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
17
18
  Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
18
19
  Requires-Dist: mypy>=1.7.1; extra == "dev"
19
- Requires-Dist: black[jupyter]<24.5.0,>=23.12.1; extra == "dev"
20
+ Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
20
21
  Provides-Extra: docs
21
22
  Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
22
23
 
@@ -0,0 +1,10 @@
1
+ graphai/__init__.py,sha256=kZJ21W6gwN-eRzvrWQf8xDTPCIWTIuyWq1IGDS9tn7Y,110
2
+ graphai/callback.py,sha256=M2gEpj7uVvANg2dVgxKFMUSPIM362YSilShMIFfrr8s,7351
3
+ graphai/graph.py,sha256=p5iAXQqx0266bAS2zGztv34UXW3-TiA7d2XIsKxAW54,11744
4
+ graphai/utils.py,sha256=LlL-Wx643nIeRFAl2xcv0crNQcA_0563epRo8ZsyL40,6898
5
+ graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
6
+ graphai/nodes/base.py,sha256=-ZOfJhxews5CGutYB5lfoIVvJ6dqdYJWeeKzDMz9odg,7624
7
+ graphai_lib-0.0.6.dist-info/METADATA,sha256=DZ-RIueKiVmtUJN2vXm02ZD-_aB5uiduWwN6HxriP4w,1006
8
+ graphai_lib-0.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ graphai_lib-0.0.6.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
10
+ graphai_lib-0.0.6.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (77.0.3)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,10 +0,0 @@
1
- graphai/__init__.py,sha256=EHigFOWewDXLZXbdfjZH9kdLPhw6NT0ChS77lNAVAA8,109
2
- graphai/callback.py,sha256=K-h44pyL2VLXwJzIB_bcVYp5R6xv8zNca5FmN6994Uk,7598
3
- graphai/graph.py,sha256=EALHEhbXAaJmTvm7cL3Tdh0moRIw7lIyJDCNnCts2QA,10335
4
- graphai/utils.py,sha256=zrgpk82rIn7lwh631KhN-OgMAJMdbm0k5GPL1eMf2sQ,4522
5
- graphai/nodes/__init__.py,sha256=4826Ubk5yUfbVH7F8DmoTKQyax624Q2QJHsGxqgQ_ng,73
6
- graphai/nodes/base.py,sha256=SZdYhFfXdtFmabFbMRcEGd8_h8w-g6s4I7hMEo6JCk8,6331
7
- graphai_lib-0.0.4.dist-info/METADATA,sha256=wfxK82ZO-slBl-xVNmdbYGVvPo7YKm8-SLfdrkNWtKs,982
8
- graphai_lib-0.0.4.dist-info/WHEEL,sha256=1tXe9gY0PYatrMPMDd6jXqjfpz_B-Wqm32CPfRC58XU,91
9
- graphai_lib-0.0.4.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
10
- graphai_lib-0.0.4.dist-info/RECORD,,