graphai-lib 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,13 @@
  Metadata-Version: 2.4
  Name: graphai-lib
- Version: 0.0.4
+ Version: 0.0.5
  Summary: Not an AI framework
  Requires-Python: <3.14,>=3.10
  Description-Content-Type: text/markdown
- Requires-Dist: semantic-router>=0.1.5
  Requires-Dist: networkx>=3.4.2
  Requires-Dist: matplotlib>=3.10.0
+ Requires-Dist: pydantic>=2.11.1
+ Requires-Dist: colorlog>=6.9.0
  Provides-Extra: dev
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -16,7 +17,7 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
  Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
  Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
  Requires-Dist: mypy>=1.7.1; extra == "dev"
- Requires-Dist: black[jupyter]<24.5.0,>=23.12.1; extra == "dev"
+ Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
  Provides-Extra: docs
  Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
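The metadata change above swaps the semantic-router runtime dependency for direct pydantic and colorlog requirements, and retires the pinned black dev dependency in favour of types-networkx. A quick way to confirm the new requirement set against an installed copy, as a minimal standard-library sketch (assumes graphai-lib 0.0.5 is installed in the current environment):

```python
from importlib.metadata import requires, version

# Both values come straight from the PKG-INFO fields shown above.
print(version("graphai-lib"))  # expected: 0.0.5
for requirement in requires("graphai-lib") or []:
    print(requirement)  # pydantic and colorlog appear; semantic-router does not
```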
@@ -1,4 +1,4 @@
  from graphai.graph import Graph
  from graphai.nodes import node, router
 
- __all__ = ["node", "router", "Graph"]
+ __all__ = ["node", "router", "Graph"]
@@ -1,12 +1,12 @@
  import asyncio
  from pydantic import Field
- from typing import Optional
+ from typing import Optional, Any
  from collections.abc import AsyncIterator
- from semantic_router.utils.logger import logger
 
 
  log_stream = True
 
+
  class Callback:
      identifier: str = Field(
          default="graphai",
@@ -14,7 +14,7 @@ class Callback:
              "The identifier for special tokens. This allows us to easily "
              "identify special tokens in the stream so we can handle them "
              "correctly in any downstream process."
-         )
+         ),
      )
      special_token_format: str = Field(
          default="<{identifier}:{token}:{params}>",
@@ -32,8 +32,8 @@ class Callback:
          examples=[
              "<{identifier}:{token}:{params}>",
              "<[{identifier} | {token} | {params}]>",
-             "<{token}:{params}>"
-         ]
+             "<{token}:{params}>",
+         ],
      )
      token_format: str = Field(
          default="{token}",
@@ -41,27 +41,23 @@ class Callback:
              "The format for streamed tokens. This is used to format the "
              "tokens typically returned from LLMs. By default, no special "
              "formatting is applied."
-         )
+         ),
      )
      _first_token: bool = Field(
          default=True,
          description="Whether this is the first token in the stream.",
-         exclude=True
+         exclude=True,
      )
      _current_node_name: Optional[str] = Field(
-         default=None,
-         description="The name of the current node.",
-         exclude=True
+         default=None, description="The name of the current node.", exclude=True
      )
      _active: bool = Field(
-         default=True,
-         description="Whether the callback is active.",
-         exclude=True
+         default=True, description="Whether the callback is active.", exclude=True
      )
      _done: bool = Field(
          default=False,
          description="Whether the stream is done and should be closed.",
-         exclude=True
+         exclude=True,
      )
      queue: asyncio.Queue
 
@@ -83,7 +79,7 @@ class Callback:
      @property
      def first_token(self) -> bool:
          return self._first_token
-
+
      @first_token.setter
      def first_token(self, value: bool):
          self._first_token = value
@@ -91,7 +87,7 @@ class Callback:
      @property
      def current_node_name(self) -> Optional[str]:
          return self._current_node_name
-
+
      @current_node_name.setter
      def current_node_name(self, value: Optional[str]):
          self._current_node_name = value
@@ -99,7 +95,7 @@ class Callback:
      @property
      def active(self) -> bool:
          return self._active
-
+
      @active.setter
      def active(self, value: bool):
          self._active = value
@@ -110,7 +106,7 @@ class Callback:
          self._check_node_name(node_name=node_name)
          # otherwise we just assume node is correct and send token
          self.queue.put_nowait(token)
-
+
      async def acall(self, token: str, node_name: Optional[str] = None):
          # TODO JB: do we need to have `node_name` param?
          if self._done:
@@ -118,16 +114,13 @@ class Callback:
          self._check_node_name(node_name=node_name)
          # otherwise we just assume node is correct and send token
          self.queue.put_nowait(token)
-
+
      async def aiter(self) -> AsyncIterator[str]:
          """Used by receiver to get the tokens from the stream queue. Creates
          a generator that yields tokens from the queue until the END token is
          received.
          """
-         end_token = await self._build_special_token(
-             name="END",
-             params=None
-         )
+         end_token = await self._build_special_token(name="END", params=None)
          while True:  # Keep going until we see the END token
              try:
                  if self._done and self.queue.empty():
@@ -142,8 +135,7 @@ class Callback:
                      self._done = True  # Mark as done after processing all tokens
 
      async def start_node(self, node_name: str, active: bool = True):
-         """Starts a new node and emits the start token.
-         """
+         """Starts a new node and emits the start token."""
          if self._done:
              raise RuntimeError("Cannot start node on a closed stream")
          self.current_node_name = node_name
@@ -152,27 +144,23 @@ class Callback:
          self.active = active
          if self.active:
              token = await self._build_special_token(
-                 name=f"{self.current_node_name}:start",
-                 params=None
+                 name=f"{self.current_node_name}:start", params=None
              )
              self.queue.put_nowait(token)
              # TODO JB: should we use two tokens here?
              node_token = await self._build_special_token(
-                 name=self.current_node_name,
-                 params=None
+                 name=self.current_node_name, params=None
              )
              self.queue.put_nowait(node_token)
-
+
      async def end_node(self, node_name: str):
-         """Emits the end token for the current node.
-         """
+         """Emits the end token for the current node."""
          if self._done:
              raise RuntimeError("Cannot end node on a closed stream")
-         #self.current_node_name = node_name
+         # self.current_node_name = node_name
          if self.active:
              node_token = await self._build_special_token(
-                 name=f"{self.current_node_name}:end",
-                 params=None
+                 name=f"{self.current_node_name}:end", params=None
              )
              self.queue.put_nowait(node_token)
 
@@ -182,10 +170,7 @@ class Callback:
          """
          if self._done:
              return
-         end_token = await self._build_special_token(
-             name="END",
-             params=None
-         )
+         end_token = await self._build_special_token(name="END", params=None)
          self._done = True  # Set done before putting the end token
          self.queue.put_nowait(end_token)
          # Don't wait for queue.join() as it can cause deadlock
@@ -198,8 +183,10 @@ class Callback:
              raise ValueError(
                  f"Node name mismatch: {self.current_node_name} != {node_name}"
              )
-
+
-     async def _build_special_token(self, name: str, params: dict[str, any] | None = None):
+     async def _build_special_token(
+         self, name: str, params: dict[str, Any] | None = None
+     ):
          if params:
              params_str = ",".join([f"{k}={v}" for k, v in params.items()])
          else:
@@ -209,7 +196,5 @@ class Callback:
          else:
              identifier = ""
          return self.special_token_format.format(
-             identifier=identifier,
-             token=name,
-             params=params_str
+             identifier=identifier, token=name, params=params_str
          )
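
Most of the Callback hunks are formatting-only (trailing commas, collapsed call sites, trailing-whitespace cleanup), but two are substantive: the unused semantic-router logger import is gone, and `_build_special_token` now annotates `params` as `dict[str, Any]` rather than the invalid lowercase `any`. The END-token protocol that `aiter()` implements can be sketched independently of the class; this is a minimal, runnable analogue, with token strings shaped by the default `special_token_format` `<{identifier}:{token}:{params}>`:

```python
import asyncio

END = "<graphai:END:>"  # the shape the default special_token_format gives the END token

async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()
    # A producer wraps plain tokens in node start/end markers, then emits END.
    for tok in ["<graphai:chat:start:>", "hello ", "world", "<graphai:chat:end:>", END]:
        queue.put_nowait(tok)
    # Mirrors Callback.aiter(): drain the queue until the END token arrives.
    while True:
        tok = await queue.get()
        if tok == END:
            break
        print(tok, end="", flush=True)

asyncio.run(main())
```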
@@ -1,11 +1,13 @@
  from typing import List, Dict, Any, Optional
  from graphai.nodes.base import _Node
  from graphai.callback import Callback
- from semantic_router.utils.logger import logger
+ from graphai.utils import logger
 
 
  class Graph:
-     def __init__(self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None):
+     def __init__(
+         self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
+     ):
          self.nodes: Dict[str, _Node] = {}
          self.edges: List[Any] = []
          self.start_node: Optional[_Node] = None
@@ -49,7 +51,7 @@ class Graph:
 
      def add_edge(self, source: _Node | str, destination: _Node | str):
          """Adds an edge between two nodes that already exist in the graph.
-
+
          Args:
              source: The source node or its name.
              destination: The destination node or its name.
@@ -60,7 +62,7 @@ class Graph:
              source_node = self.nodes.get(source)
          else:
              # Check if it's a node-like object by looking for required attributes
-             if hasattr(source, 'name'):
+             if hasattr(source, "name"):
                  source_node = self.nodes.get(source.name)
                  if source_node is None:
                      raise ValueError(
@@ -71,7 +73,7 @@ class Graph:
              destination_node = self.nodes.get(destination)
          else:
              # Check if it's a node-like object by looking for required attributes
-             if hasattr(destination, 'name'):
+             if hasattr(destination, "name"):
                  destination_node = self.nodes.get(destination.name)
                  if destination_node is None:
                      raise ValueError(
@@ -80,7 +82,9 @@ class Graph:
          edge = Edge(source_node, destination_node)
          self.edges.append(edge)
 
-     def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
+     def add_router(
+         self, sources: list[_Node], router: _Node, destinations: List[_Node]
+     ):
          if not router.is_router:
              raise TypeError("A router object must be passed to the router parameter.")
          [self.add_edge(source, router) for source in sources]
@@ -126,7 +130,9 @@ class Graph:
                  # add callback tokens and param here if we are streaming
                  await self.callback.start_node(node_name=current_node.name)
                  # Include graph's internal state in the node execution context
-                 output = await current_node.invoke(input=state, callback=self.callback, state=self.state)
+                 output = await current_node.invoke(
+                     input=state, callback=self.callback, state=self.state
+                 )
                  self._validate_output(output=output, node_name=current_node.name)
                  await self.callback.end_node(node_name=current_node.name)
              else:
@@ -164,13 +170,13 @@ class Graph:
 
      def _get_node_by_name(self, node_name: str) -> _Node:
          """Get a node by its name.
-
+
          Args:
              node_name: The name of the node to find.
-
+
          Returns:
              The node with the given name.
-
+
          Raises:
              Exception: If no node with the given name is found.
          """
@@ -191,12 +197,16 @@ class Graph:
          try:
              import networkx as nx
          except ImportError:
-             raise ImportError("NetworkX is required for visualization. Please install it with 'pip install networkx'.")
+             raise ImportError(
+                 "NetworkX is required for visualization. Please install it with 'pip install networkx'."
+             )
 
          try:
              import matplotlib.pyplot as plt
          except ImportError:
-             raise ImportError("Matplotlib is required for visualization. Please install it with 'pip install matplotlib'.")
+             raise ImportError(
+                 "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
+             )
 
          G = nx.DiGraph()
 
@@ -207,7 +217,9 @@ class Graph:
              G.add_edge(edge.source.name, edge.destination.name)
 
          if nx.is_directed_acyclic_graph(G):
-             logger.info("The graph is acyclic. Visualization will use a topological layout.")
+             logger.info(
+                 "The graph is acyclic. Visualization will use a topological layout."
+             )
              # Use topological layout if acyclic
              # Compute the topological generations
              generations = list(nx.topological_generations(G))
@@ -241,20 +253,30 @@ class Graph:
              pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
 
          else:
-             print("Warning: The graph contains cycles. Visualization will use a spring layout.")
+             print(
+                 "Warning: The graph contains cycles. Visualization will use a spring layout."
+             )
              pos = nx.spring_layout(G, k=1, iterations=50)
 
          plt.figure(figsize=(8, 6))
-         nx.draw(G, pos, with_labels=True, node_color='lightblue',
-                 node_size=3000, font_size=8, font_weight='bold',
-                 arrows=True, edge_color='gray', arrowsize=20)
+         nx.draw(
+             G,
+             pos,
+             with_labels=True,
+             node_color="lightblue",
+             node_size=3000,
+             font_size=8,
+             font_weight="bold",
+             arrows=True,
+             edge_color="gray",
+             arrowsize=20,
+         )
 
-         plt.axis('off')
+         plt.axis("off")
          plt.show()
 
 
-
  class Edge:
      def __init__(self, source, destination):
          self.source = source
-         self.destination = destination
+         self.destination = destination
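
Besides the logger now coming from `graphai.utils` instead of semantic-router, the Graph hunks are reformatting only. For orientation, here is a sketch of how the touched APIs fit together; `add_node` does not appear in this diff and is assumed from context, while `Graph(max_steps=..., initial_state=...)`, `add_edge` accepting either node objects or names, and the `@node` decorator are all visible in the hunks above and below:

```python
from graphai import Graph, node

@node(start=True)
async def fetch(query: str):
    return {"text": f"results for {query}"}

@node(end=True)
async def summarize(text: str):
    return {"summary": text.upper()}

graph = Graph(max_steps=10, initial_state={"run_id": 1})
graph.add_node(fetch)      # hypothetical wiring call; not part of this diff
graph.add_node(summarize)
graph.add_edge(fetch, "summarize")  # a node object or its name both resolve
```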
@@ -1,3 +1,3 @@
  from graphai.nodes.base import node, router
 
- __all__ = ["node", "router"]
+ __all__ = ["node", "router"]
@@ -1,5 +1,6 @@
  import inspect
  from typing import Any, Callable, Dict, Optional
+ from pydantic import Field
 
  from graphai.callback import Callback
  from graphai.utils import FunctionSchema
@@ -9,7 +10,11 @@ class NodeMeta(type):
      @staticmethod
      def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
          init_signature = inspect.signature(cls_type.__init__)
-         init_params = {name: arg for name, arg in init_signature.parameters.items() if name != "self"}
+         init_params = {
+             name: arg
+             for name, arg in init_signature.parameters.items()
+             if name != "self"
+         }
          return init_params
 
      def __call__(cls, *args, **kwargs):
@@ -33,18 +38,32 @@ class _Node:
          stream: bool = False,
          name: str | None = None,
      ) -> Callable:
-         """Decorator validating node structure.
-         """
+         """Decorator validating node structure."""
          if not callable(func):
              raise ValueError("Node must be a callable function.")
-
+
          func_signature = inspect.signature(func)
-         schema = FunctionSchema(func)
+         schema: FunctionSchema = FunctionSchema.from_callable(func)
 
          class NodeClass:
              _func_signature = func_signature
-             is_router = None
-             _stream = stream
+             is_router: bool = Field(
+                 default=False, description="Whether the node is a router."
+             )
+             # following attributes will be overridden by the decorator
+             name: str | None = Field(default=None, description="The name of the node.")
+             is_start: bool = Field(
+                 default=False, description="Whether the node is the start of the graph."
+             )
+             is_end: bool = Field(
+                 default=False, description="Whether the node is the end of the graph."
+             )
+             schema: FunctionSchema | None = Field(
+                 default=None, description="The schema of the node."
+             )
+             stream: bool = Field(
+                 default=False, description="Whether the node includes streaming object."
+             )
 
              def __init__(self):
                  self._expected_params = set(self._func_signature.parameters.keys())
@@ -56,9 +75,13 @@ class _Node:
 
              async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
                  # filter out unexpected keyword args
-                 expected_kwargs = {k: v for k, v in kwargs.items() if k in self._expected_params}
+                 expected_kwargs = {
+                     k: v for k, v in kwargs.items() if k in self._expected_params
+                 }
                  # Convert args to kwargs based on the function signature
-                 args_names = list(self._func_signature.parameters.keys())[1:len(args)+1] # skip 'self'
+                 args_names = list(self._func_signature.parameters.keys())[
+                     1 : len(args) + 1
+                 ]  # skip 'self'
                  expected_args_kwargs = dict(zip(args_names, args))
                  # Combine filtered args and kwargs
                  combined_params = {**expected_args_kwargs, **expected_kwargs}
@@ -87,7 +110,6 @@ class _Node:
                  )
                  return filtered_params
 
-
              @classmethod
              def get_signature(cls):
                  """Returns the signature of the decorated function as LLM readable
@@ -97,15 +119,24 @@ class _Node:
                  if NodeClass._func_signature:
                      for param in NodeClass._func_signature.parameters.values():
                          if param.default is param.empty:
-                             signature_components.append(f"{param.name}: {param.annotation}")
+                             signature_components.append(
+                                 f"{param.name}: {param.annotation}"
+                             )
                          else:
-                             signature_components.append(f"{param.name}: {param.annotation} = {param.default}")
+                             signature_components.append(
+                                 f"{param.name}: {param.annotation} = {param.default}"
+                             )
                  else:
                      return "No signature"
                  return "\n".join(signature_components)
 
              @classmethod
-             async def invoke(cls, input: Dict[str, Any], callback: Optional[Callback] = None, state: Optional[Dict[str, Any]] = None):
+             async def invoke(
+                 cls,
+                 input: Dict[str, Any],
+                 callback: Optional[Callback] = None,
+                 state: Optional[Dict[str, Any]] = None,
+             ):
                  if callback:
                      if stream:
                          input["callback"] = callback
@@ -116,13 +147,16 @@ class _Node:
                  # Add state to the input if present and the parameter exists in the function signature
                  if state is not None and "state" in cls._func_signature.parameters:
                      input["state"] = state
-
+
                  instance = cls()
                  out = await instance.execute(**input)
                  return out
 
          NodeClass.__name__ = func.__name__
-         NodeClass.name = name or func.__name__
+         node_class_name = name or func.__name__
+         if node_class_name is None:
+             raise ValueError("Unexpected error: node name not set.")
+         NodeClass.name = node_class_name
          NodeClass.__doc__ = func.__doc__
          NodeClass.is_start = start
          NodeClass.is_end = end
@@ -141,8 +175,11 @@ class _Node:
      ):
          # We must wrap the call to the decorator in a function for it to work
          # correctly with or without parenthesis
-         def wrap(func: Callable, start=start, end=end, stream=stream, name=name) -> Callable:
+         def wrap(
+             func: Callable, start=start, end=end, stream=stream, name=name
+         ) -> Callable:
              return self._node(func=func, start=start, end=end, stream=stream, name=name)
+
          if func:
              # Decorator is called without parenthesis
              return wrap(func=func, start=start, end=end, stream=stream, name=name)
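
The decorator now declares typed, described class attributes via pydantic `Field` instead of the bare `is_router = None` / `_stream = stream`, guards against a missing node name, and builds the schema through the new `FunctionSchema.from_callable` classmethod. A short sketch of what the decorator produces (assumes `node` is the `_Node` instance exported by `graphai.nodes`, and that `execute()`, defined outside these hunks, forwards to the wrapped function):

```python
import asyncio
from graphai import node

@node(start=True)
async def greet(name: str):
    """Say hello."""
    return {"message": f"hello {name}"}

print(greet.name)      # "greet" -- set by the decorator, see the hunk above
print(greet.is_start)  # True
out = asyncio.run(greet.invoke(input={"name": "world"}))
print(out)             # {"message": "hello world"}
```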
@@ -0,0 +1,205 @@
+ import inspect
+ from typing import Any, Callable, List, Optional
+ from pydantic import BaseModel, Field
+ import logging
+
+ import colorlog
+
+
+ class CustomFormatter(colorlog.ColoredFormatter):
+     """Custom formatter for the logger."""
+
+     def __init__(self):
+         super().__init__(
+             "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
+             datefmt="%Y-%m-%d %H:%M:%S",
+             log_colors={
+                 "DEBUG": "cyan",
+                 "INFO": "green",
+                 "WARNING": "yellow",
+                 "ERROR": "red",
+                 "CRITICAL": "bold_red",
+             },
+             reset=True,
+             style="%",
+         )
+
+
+ def add_coloured_handler(logger):
+     """Add a coloured handler to the logger."""
+     formatter = CustomFormatter()
+     console_handler = logging.StreamHandler()
+     console_handler.setFormatter(formatter)
+     logger.addHandler(console_handler)
+     return logger
+
+
+ def setup_custom_logger(name):
+     """Set up a custom logger."""
+     logger = logging.getLogger(name)
+
+     if not logger.hasHandlers():
+         add_coloured_handler(logger)
+         logger.setLevel(logging.INFO)
+         logger.propagate = False
+
+     return logger
+
+
+ logger: logging.Logger = setup_custom_logger(__name__)
+
+
+ def openai_type_mapping(param_type: str) -> str:
+     if param_type == "int":
+         return "number"
+     elif param_type == "float":
+         return "number"
+     elif param_type == "str":
+         return "string"
+     elif param_type == "bool":
+         return "boolean"
+     else:
+         return "object"
+
+
+ class Parameter(BaseModel):
+     """Parameter for a function.
+
+     :param name: The name of the parameter.
+     :type name: str
+     :param description: The description of the parameter.
+     :type description: Optional[str]
+     :param type: The type of the parameter.
+     :type type: str
+     :param default: The default value of the parameter.
+     :type default: Any
+     :param required: Whether the parameter is required.
+     :type required: bool
+     """
+
+     name: str = Field(description="The name of the parameter")
+     description: Optional[str] = Field(
+         default=None, description="The description of the parameter"
+     )
+     type: str = Field(description="The type of the parameter")
+     default: Any = Field(description="The default value of the parameter")
+     required: bool = Field(description="Whether the parameter is required")
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert the parameter to a dictionary for a standard dictionary-based function schema.
+         This is the most common format used by LLM providers, including OpenAI, Ollama, and others.
+
+         :return: The parameter in dictionary format.
+         :rtype: dict[str, Any]
+         """
+         return {
+             self.name: {
+                 "description": self.description,
+                 "type": openai_type_mapping(self.type),
+             }
+         }
+
+
+ class FunctionSchema(BaseModel):
+     """Class that consumes a function and can return a schema required by
+     different LLMs for function calling.
+     """
+
+     name: str = Field(description="The name of the function")
+     description: str = Field(description="The description of the function")
+     signature: str = Field(description="The signature of the function")
+     output: str = Field(description="The output of the function")
+     parameters: list[Parameter] = Field(description="The parameters of the function")
+
+     @classmethod
+     def from_callable(cls, function: Callable) -> "FunctionSchema":
+         """Initialize the FunctionSchema.
+
+         :param function: The function to consume.
+         :type function: Callable
+         """
+         if callable(function):
+             name = function.__name__
+             description = str(inspect.getdoc(function))
+             if description is None or description == "":
+                 logger.warning(f"Function {name} has no docstring")
+             signature = str(inspect.signature(function))
+             output = str(inspect.signature(function).return_annotation)
+             parameters = []
+             for param in inspect.signature(function).parameters.values():
+                 parameters.append(
+                     Parameter(
+                         name=param.name,
+                         type=param.annotation.__name__,
+                         default=param.default,
+                         required=param.default is inspect.Parameter.empty,
+                     )
+                 )
+             return cls.model_construct(
+                 name=name,
+                 description=description,
+                 signature=signature,
+                 output=output,
+                 parameters=parameters,
+             )
+         elif isinstance(function, BaseModel):
+             raise NotImplementedError("Pydantic BaseModel not implemented yet.")
+         else:
+             raise TypeError("Function must be a Callable or BaseModel")
+
+     @classmethod
+     def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
+         signature_parts = []
+         for field_name, field_model in model.__annotations__.items():
+             field_info = model.model_fields[field_name]
+             default_value = field_info.default
+             if default_value:
+                 default_repr = repr(default_value)
+                 signature_part = (
+                     f"{field_name}: {field_model.__name__} = {default_repr}"
+                 )
+             else:
+                 signature_part = f"{field_name}: {field_model.__name__}"
+             signature_parts.append(signature_part)
+         signature = f"({', '.join(signature_parts)}) -> str"
+         return cls.model_construct(
+             name=model.__class__.__name__,
+             description=model.__doc__ or "",
+             signature=signature,
+             output="",  # TODO: Implement output
+             parameters=[],
+         )
+
+     def to_dict(self) -> dict:
+         schema_dict = {
+             "type": "function",
+             "function": {
+                 "name": self.name,
+                 "description": self.description,
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         k: v for param in self.parameters for k, v in param.to_dict().items()
+                     },
+                     "required": [
+                         param.name for param in self.parameters if param.required
+                     ],
+                 },
+             },
+         }
+         return schema_dict
+
+     def to_openai(self) -> dict:
+         return self.to_dict()
+
+
+ DEFAULT = set(["default", "openai", "ollama", "litellm"])
+
+
+ def get_schemas(callables: List[Callable], format: str = "default") -> list[dict]:
+     if format in DEFAULT:
+         return [
+             FunctionSchema.from_callable(callable).to_dict() for callable in callables
+         ]
+     else:
+         raise ValueError(f"Format {format} not supported")
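
This new module absorbs the logger that previously came from `semantic_router.utils.logger` and replaces the deleted schema utilities (final hunk below): `FunctionSchema` is now a pydantic `BaseModel` built via the `from_callable` classmethod rather than an `__init__` that mutated attributes, and `to_dict()` supersedes the old `to_openai()` while keeping it as an alias. A usage sketch grounded in the code above:

```python
from graphai.utils import FunctionSchema

def add(a: int, b: int = 1) -> int:
    """Add two integers."""
    return a + b

schema = FunctionSchema.from_callable(add)
print(schema.name)       # "add"
print(schema.signature)  # "(a: int, b: int = 1) -> int"
print(schema.to_dict())
# {'type': 'function',
#  'function': {'name': 'add',
#               'description': 'Add two integers.',
#               'parameters': {'type': 'object',
#                              'properties': {'a': {'description': None, 'type': 'number'},
#                                             'b': {'description': None, 'type': 'number'}},
#                              'required': ['a']}}}
```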
@@ -1,12 +1,13 @@
  Metadata-Version: 2.4
  Name: graphai-lib
- Version: 0.0.4
+ Version: 0.0.5
  Summary: Not an AI framework
  Requires-Python: <3.14,>=3.10
  Description-Content-Type: text/markdown
- Requires-Dist: semantic-router>=0.1.5
  Requires-Dist: networkx>=3.4.2
  Requires-Dist: matplotlib>=3.10.0
+ Requires-Dist: pydantic>=2.11.1
+ Requires-Dist: colorlog>=6.9.0
  Provides-Extra: dev
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -16,7 +17,7 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
  Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
  Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
  Requires-Dist: mypy>=1.7.1; extra == "dev"
- Requires-Dist: black[jupyter]<24.5.0,>=23.12.1; extra == "dev"
+ Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
  Provides-Extra: docs
  Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
 
@@ -1,6 +1,7 @@
- semantic-router>=0.1.5
  networkx>=3.4.2
  matplotlib>=3.10.0
+ pydantic>=2.11.1
+ colorlog>=6.9.0
 
  [dev]
  ipykernel>=6.25.0
@@ -11,7 +12,7 @@ pytest-cov>=4.1.0
  pytest-xdist>=3.5.0
  pytest-asyncio>=0.24.0
  mypy>=1.7.1
- black[jupyter]<24.5.0,>=23.12.1
+ types-networkx>=3.4.2.20250319
 
  [docs]
 
@@ -1,13 +1,14 @@
  [project]
  name = "graphai-lib"
- version = "0.0.4"
+ version = "0.0.5"
  description = "Not an AI framework"
  readme = "README.md"
  requires-python = ">=3.10,<3.14"
  dependencies = [
-     "semantic-router>=0.1.5",
      "networkx>=3.4.2",
      "matplotlib>=3.10.0",
+     "pydantic>=2.11.1",
+     "colorlog>=6.9.0",
  ]
 
  [project.optional-dependencies]
@@ -20,7 +21,7 @@ dev = [
      "pytest-xdist>=3.5.0",
      "pytest-asyncio>=0.24.0",
      "mypy>=1.7.1",
-     "black[jupyter]>=23.12.1,<24.5.0",
+     "types-networkx>=3.4.2.20250319",
  ]
  docs = ["pydoc-markdown>=4.8.2 ; python_version < '3.12'"]
 
@@ -29,4 +30,4 @@ requires = ["setuptools>=61.0"]
  build-backend = "setuptools.build_meta"
 
  [tool.setuptools]
- packages = ["graphai", "graphai.nodes"]
+ packages = ["graphai", "graphai.nodes"]
@@ -1,125 +0,0 @@
- import inspect
- from typing import Any, Callable, Dict, List, Union, Optional
- from pydantic import BaseModel, Field
-
-
- class Parameter(BaseModel):
-     class Config:
-         arbitrary_types_allowed = True
-
-     name: str = Field(description="The name of the parameter")
-     description: Optional[str] = Field(
-         default=None, description="The description of the parameter"
-     )
-     type: str = Field(description="The type of the parameter")
-     default: Any = Field(description="The default value of the parameter")
-     required: bool = Field(description="Whether the parameter is required")
-
-     def to_openai(self):
-         return {
-             self.name: {
-                 "description": self.description,
-                 "type": self.type,
-             }
-         }
-
- class FunctionSchema:
-     """Class that consumes a function and can return a schema required by
-     different LLMs for function calling.
-     """
-
-     name: str = Field(description="The name of the function")
-     description: str = Field(description="The description of the function")
-     signature: str = Field(description="The signature of the function")
-     output: str = Field(description="The output of the function")
-     parameters: List[Parameter] = Field(description="The parameters of the function")
-
-     def __init__(self, function: Union[Callable, BaseModel]):
-         self.function = function
-         if callable(function):
-             self._process_function(function)
-         elif isinstance(function, BaseModel):
-             raise NotImplementedError("Pydantic BaseModel not implemented yet.")
-         else:
-             raise TypeError("Function must be a Callable or BaseModel")
-
-     def _process_function(self, function: Callable):
-         self.name = function.__name__
-         self.description = str(inspect.getdoc(function))
-         self.signature = str(inspect.signature(function))
-         self.output = str(inspect.signature(function).return_annotation)
-         parameters = []
-         for param in inspect.signature(function).parameters.values():
-             parameters.append(
-                 Parameter(
-                     name=param.name,
-                     type=param.annotation.__name__,
-                     default=param.default,
-                     required=param.default is inspect.Parameter.empty,
-                 )
-             )
-         self.parameters = parameters
-
-     def to_openai(self):
-         schema_dict = {
-             "type": "function",
-             "function": {
-                 "name": self.name,
-                 "description": self.description,
-                 "parameters": {
-                     "type": "object",
-                     "properties": {
-                         param.name: {
-                             "description": (
-                                 param.description
-                                 if isinstance(param.description, str)
-                                 else "None provided"
-                             ),
-                             "type": self._openai_type_mapping(param.type),
-                         }
-                         for param in self.parameters
-                     },
-                     "required": [
-                         param.name for param in self.parameters if param.required
-                     ],
-                 },
-             },
-         }
-         return schema_dict
-
-     def _openai_type_mapping(self, param_type: str) -> str:
-         if param_type == "int":
-             return "number"
-         elif param_type == "float":
-             return "number"
-         elif param_type == "str":
-             return "string"
-         elif param_type == "bool":
-             return "boolean"
-         else:
-             return "object"
-
-
- def get_schema_pydantic(model: BaseModel) -> Dict[str, Any]:
-     signature_parts = []
-     for field_name, field_model in model.__annotations__.items():
-         field_info = model.__fields__[field_name]
-         default_value = field_info.default
-
-         if default_value:
-             default_repr = repr(default_value)
-             signature_part = (
-                 f"{field_name}: {field_model.__name__} = {default_repr}"
-             )
-         else:
-             signature_part = f"{field_name}: {field_model.__name__}"
-
-         signature_parts.append(signature_part)
-     signature = f"({', '.join(signature_parts)}) -> str"
-     schema = FunctionSchema(
-         name=model.__class__.__name__,
-         description=model.__doc__,
-         signature=signature,
-         output="",  # TODO: Implement output
-     )
-     return schema