graphai-lib 0.0.3__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
7
- Requires-Dist: semantic-router>=0.1.5
8
7
  Requires-Dist: networkx>=3.4.2
9
8
  Requires-Dist: matplotlib>=3.10.0
9
+ Requires-Dist: pydantic>=2.11.1
10
+ Requires-Dist: colorlog>=6.9.0
10
11
  Provides-Extra: dev
11
12
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
12
13
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -16,7 +17,9 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
16
17
  Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
17
18
  Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
18
19
  Requires-Dist: mypy>=1.7.1; extra == "dev"
19
- Requires-Dist: black[jupyter]<24.5.0,>=23.12.1; extra == "dev"
20
+ Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
21
+ Provides-Extra: docs
22
+ Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
20
23
 
21
24
  # Philosophy
22
25
 
@@ -1,4 +1,4 @@
1
1
  from graphai.graph import Graph
2
2
  from graphai.nodes import node, router
3
3
 
4
- __all__ = ["node", "router", "Graph"]
4
+ __all__ = ["node", "router", "Graph"]
@@ -1,12 +1,12 @@
1
1
  import asyncio
2
2
  from pydantic import Field
3
- from typing import Optional
3
+ from typing import Optional, Any
4
4
  from collections.abc import AsyncIterator
5
- from semantic_router.utils.logger import logger
6
5
 
7
6
 
8
7
  log_stream = True
9
8
 
9
+
10
10
  class Callback:
11
11
  identifier: str = Field(
12
12
  default="graphai",
@@ -14,7 +14,7 @@ class Callback:
14
14
  "The identifier for special tokens. This allows us to easily "
15
15
  "identify special tokens in the stream so we can handle them "
16
16
  "correctly in any downstream process."
17
- )
17
+ ),
18
18
  )
19
19
  special_token_format: str = Field(
20
20
  default="<{identifier}:{token}:{params}>",
@@ -32,8 +32,8 @@ class Callback:
32
32
  examples=[
33
33
  "<{identifier}:{token}:{params}>",
34
34
  "<[{identifier} | {token} | {params}]>",
35
- "<{token}:{params}>"
36
- ]
35
+ "<{token}:{params}>",
36
+ ],
37
37
  )
38
38
  token_format: str = Field(
39
39
  default="{token}",
@@ -41,27 +41,23 @@ class Callback:
41
41
  "The format for streamed tokens. This is used to format the "
42
42
  "tokens typically returned from LLMs. By default, no special "
43
43
  "formatting is applied."
44
- )
44
+ ),
45
45
  )
46
46
  _first_token: bool = Field(
47
47
  default=True,
48
48
  description="Whether this is the first token in the stream.",
49
- exclude=True
49
+ exclude=True,
50
50
  )
51
51
  _current_node_name: Optional[str] = Field(
52
- default=None,
53
- description="The name of the current node.",
54
- exclude=True
52
+ default=None, description="The name of the current node.", exclude=True
55
53
  )
56
54
  _active: bool = Field(
57
- default=True,
58
- description="Whether the callback is active.",
59
- exclude=True
55
+ default=True, description="Whether the callback is active.", exclude=True
60
56
  )
61
57
  _done: bool = Field(
62
58
  default=False,
63
59
  description="Whether the stream is done and should be closed.",
64
- exclude=True
60
+ exclude=True,
65
61
  )
66
62
  queue: asyncio.Queue
67
63
 
@@ -83,7 +79,7 @@ class Callback:
83
79
  @property
84
80
  def first_token(self) -> bool:
85
81
  return self._first_token
86
-
82
+
87
83
  @first_token.setter
88
84
  def first_token(self, value: bool):
89
85
  self._first_token = value
@@ -91,7 +87,7 @@ class Callback:
91
87
  @property
92
88
  def current_node_name(self) -> Optional[str]:
93
89
  return self._current_node_name
94
-
90
+
95
91
  @current_node_name.setter
96
92
  def current_node_name(self, value: Optional[str]):
97
93
  self._current_node_name = value
@@ -99,7 +95,7 @@ class Callback:
99
95
  @property
100
96
  def active(self) -> bool:
101
97
  return self._active
102
-
98
+
103
99
  @active.setter
104
100
  def active(self, value: bool):
105
101
  self._active = value
@@ -110,7 +106,7 @@ class Callback:
110
106
  self._check_node_name(node_name=node_name)
111
107
  # otherwise we just assume node is correct and send token
112
108
  self.queue.put_nowait(token)
113
-
109
+
114
110
  async def acall(self, token: str, node_name: Optional[str] = None):
115
111
  # TODO JB: do we need to have `node_name` param?
116
112
  if self._done:
@@ -118,16 +114,13 @@ class Callback:
118
114
  self._check_node_name(node_name=node_name)
119
115
  # otherwise we just assume node is correct and send token
120
116
  self.queue.put_nowait(token)
121
-
117
+
122
118
  async def aiter(self) -> AsyncIterator[str]:
123
119
  """Used by receiver to get the tokens from the stream queue. Creates
124
120
  a generator that yields tokens from the queue until the END token is
125
121
  received.
126
122
  """
127
- end_token = await self._build_special_token(
128
- name="END",
129
- params=None
130
- )
123
+ end_token = await self._build_special_token(name="END", params=None)
131
124
  while True: # Keep going until we see the END token
132
125
  try:
133
126
  if self._done and self.queue.empty():
@@ -142,8 +135,7 @@ class Callback:
142
135
  self._done = True # Mark as done after processing all tokens
143
136
 
144
137
  async def start_node(self, node_name: str, active: bool = True):
145
- """Starts a new node and emits the start token.
146
- """
138
+ """Starts a new node and emits the start token."""
147
139
  if self._done:
148
140
  raise RuntimeError("Cannot start node on a closed stream")
149
141
  self.current_node_name = node_name
@@ -152,27 +144,23 @@ class Callback:
152
144
  self.active = active
153
145
  if self.active:
154
146
  token = await self._build_special_token(
155
- name=f"{self.current_node_name}:start",
156
- params=None
147
+ name=f"{self.current_node_name}:start", params=None
157
148
  )
158
149
  self.queue.put_nowait(token)
159
150
  # TODO JB: should we use two tokens here?
160
151
  node_token = await self._build_special_token(
161
- name=self.current_node_name,
162
- params=None
152
+ name=self.current_node_name, params=None
163
153
  )
164
154
  self.queue.put_nowait(node_token)
165
-
155
+
166
156
  async def end_node(self, node_name: str):
167
- """Emits the end token for the current node.
168
- """
157
+ """Emits the end token for the current node."""
169
158
  if self._done:
170
159
  raise RuntimeError("Cannot end node on a closed stream")
171
- #self.current_node_name = node_name
160
+ # self.current_node_name = node_name
172
161
  if self.active:
173
162
  node_token = await self._build_special_token(
174
- name=f"{self.current_node_name}:end",
175
- params=None
163
+ name=f"{self.current_node_name}:end", params=None
176
164
  )
177
165
  self.queue.put_nowait(node_token)
178
166
 
@@ -182,10 +170,7 @@ class Callback:
182
170
  """
183
171
  if self._done:
184
172
  return
185
- end_token = await self._build_special_token(
186
- name="END",
187
- params=None
188
- )
173
+ end_token = await self._build_special_token(name="END", params=None)
189
174
  self._done = True # Set done before putting the end token
190
175
  self.queue.put_nowait(end_token)
191
176
  # Don't wait for queue.join() as it can cause deadlock
@@ -198,8 +183,10 @@ class Callback:
198
183
  raise ValueError(
199
184
  f"Node name mismatch: {self.current_node_name} != {node_name}"
200
185
  )
201
-
202
- async def _build_special_token(self, name: str, params: dict[str, any] | None = None):
186
+
187
+ async def _build_special_token(
188
+ self, name: str, params: dict[str, Any] | None = None
189
+ ):
203
190
  if params:
204
191
  params_str = ",".join([f"{k}={v}" for k, v in params.items()])
205
192
  else:
@@ -209,7 +196,5 @@ class Callback:
209
196
  else:
210
197
  identifier = ""
211
198
  return self.special_token_format.format(
212
- identifier=identifier,
213
- token=name,
214
- params=params_str
199
+ identifier=identifier, token=name, params=params_str
215
200
  )
@@ -1,11 +1,13 @@
1
1
  from typing import List, Dict, Any, Optional
2
2
  from graphai.nodes.base import _Node
3
3
  from graphai.callback import Callback
4
- from semantic_router.utils.logger import logger
4
+ from graphai.utils import logger
5
5
 
6
6
 
7
7
  class Graph:
8
- def __init__(self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None):
8
+ def __init__(
9
+ self, max_steps: int = 10, initial_state: Optional[Dict[str, Any]] = None
10
+ ):
9
11
  self.nodes: Dict[str, _Node] = {}
10
12
  self.edges: List[Any] = []
11
13
  self.start_node: Optional[_Node] = None
@@ -49,7 +51,7 @@ class Graph:
49
51
 
50
52
  def add_edge(self, source: _Node | str, destination: _Node | str):
51
53
  """Adds an edge between two nodes that already exist in the graph.
52
-
54
+
53
55
  Args:
54
56
  source: The source node or its name.
55
57
  destination: The destination node or its name.
@@ -60,7 +62,7 @@ class Graph:
60
62
  source_node = self.nodes.get(source)
61
63
  else:
62
64
  # Check if it's a node-like object by looking for required attributes
63
- if hasattr(source, 'name'):
65
+ if hasattr(source, "name"):
64
66
  source_node = self.nodes.get(source.name)
65
67
  if source_node is None:
66
68
  raise ValueError(
@@ -71,7 +73,7 @@ class Graph:
71
73
  destination_node = self.nodes.get(destination)
72
74
  else:
73
75
  # Check if it's a node-like object by looking for required attributes
74
- if hasattr(destination, 'name'):
76
+ if hasattr(destination, "name"):
75
77
  destination_node = self.nodes.get(destination.name)
76
78
  if destination_node is None:
77
79
  raise ValueError(
@@ -80,7 +82,9 @@ class Graph:
80
82
  edge = Edge(source_node, destination_node)
81
83
  self.edges.append(edge)
82
84
 
83
- def add_router(self, sources: list[_Node], router: _Node, destinations: List[_Node]):
85
+ def add_router(
86
+ self, sources: list[_Node], router: _Node, destinations: List[_Node]
87
+ ):
84
88
  if not router.is_router:
85
89
  raise TypeError("A router object must be passed to the router parameter.")
86
90
  [self.add_edge(source, router) for source in sources]
@@ -126,7 +130,9 @@ class Graph:
126
130
  # add callback tokens and param here if we are streaming
127
131
  await self.callback.start_node(node_name=current_node.name)
128
132
  # Include graph's internal state in the node execution context
129
- output = await current_node.invoke(input=state, callback=self.callback, state=self.state)
133
+ output = await current_node.invoke(
134
+ input=state, callback=self.callback, state=self.state
135
+ )
130
136
  self._validate_output(output=output, node_name=current_node.name)
131
137
  await self.callback.end_node(node_name=current_node.name)
132
138
  else:
@@ -164,13 +170,13 @@ class Graph:
164
170
 
165
171
  def _get_node_by_name(self, node_name: str) -> _Node:
166
172
  """Get a node by its name.
167
-
173
+
168
174
  Args:
169
175
  node_name: The name of the node to find.
170
-
176
+
171
177
  Returns:
172
178
  The node with the given name.
173
-
179
+
174
180
  Raises:
175
181
  Exception: If no node with the given name is found.
176
182
  """
@@ -191,12 +197,16 @@ class Graph:
191
197
  try:
192
198
  import networkx as nx
193
199
  except ImportError:
194
- raise ImportError("NetworkX is required for visualization. Please install it with 'pip install networkx'.")
200
+ raise ImportError(
201
+ "NetworkX is required for visualization. Please install it with 'pip install networkx'."
202
+ )
195
203
 
196
204
  try:
197
205
  import matplotlib.pyplot as plt
198
206
  except ImportError:
199
- raise ImportError("Matplotlib is required for visualization. Please install it with 'pip install matplotlib'.")
207
+ raise ImportError(
208
+ "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
209
+ )
200
210
 
201
211
  G = nx.DiGraph()
202
212
 
@@ -207,7 +217,9 @@ class Graph:
207
217
  G.add_edge(edge.source.name, edge.destination.name)
208
218
 
209
219
  if nx.is_directed_acyclic_graph(G):
210
- logger.info("The graph is acyclic. Visualization will use a topological layout.")
220
+ logger.info(
221
+ "The graph is acyclic. Visualization will use a topological layout."
222
+ )
211
223
  # Use topological layout if acyclic
212
224
  # Compute the topological generations
213
225
  generations = list(nx.topological_generations(G))
@@ -241,20 +253,30 @@ class Graph:
241
253
  pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
242
254
 
243
255
  else:
244
- print("Warning: The graph contains cycles. Visualization will use a spring layout.")
256
+ print(
257
+ "Warning: The graph contains cycles. Visualization will use a spring layout."
258
+ )
245
259
  pos = nx.spring_layout(G, k=1, iterations=50)
246
260
 
247
261
  plt.figure(figsize=(8, 6))
248
- nx.draw(G, pos, with_labels=True, node_color='lightblue',
249
- node_size=3000, font_size=8, font_weight='bold',
250
- arrows=True, edge_color='gray', arrowsize=20)
262
+ nx.draw(
263
+ G,
264
+ pos,
265
+ with_labels=True,
266
+ node_color="lightblue",
267
+ node_size=3000,
268
+ font_size=8,
269
+ font_weight="bold",
270
+ arrows=True,
271
+ edge_color="gray",
272
+ arrowsize=20,
273
+ )
251
274
 
252
- plt.axis('off')
275
+ plt.axis("off")
253
276
  plt.show()
254
277
 
255
278
 
256
-
257
279
  class Edge:
258
280
  def __init__(self, source, destination):
259
281
  self.source = source
260
- self.destination = destination
282
+ self.destination = destination
@@ -0,0 +1,3 @@
1
+ from graphai.nodes.base import node, router
2
+
3
+ __all__ = ["node", "router"]
@@ -0,0 +1,191 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, Optional
3
+ from pydantic import Field
4
+
5
+ from graphai.callback import Callback
6
+ from graphai.utils import FunctionSchema
7
+
8
+
9
class NodeMeta(type):
    """Metaclass that instantiates classes using keyword arguments only.

    Positional constructor arguments are mapped onto the names of the class's
    ``__init__`` parameters (excluding ``self``) before the real constructor
    is called.
    """

    @staticmethod
    def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """Map positional ``args`` onto ``cls_type.__init__`` parameter names.

        Args:
            cls_type: The class whose ``__init__`` signature is inspected.
            args: Positional arguments passed to the constructor.

        Returns:
            A dict mapping each parameter name to its positional value.

        Raises:
            TypeError: If more positional arguments are given than
                ``__init__`` accepts by name.
        """
        init_signature = inspect.signature(cls_type.__init__)
        # Only positional-or-keyword parameters can be re-passed by name.
        positional_names = [
            name
            for name, param in init_signature.parameters.items()
            if name != "self" and param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
        ]
        if len(args) > len(positional_names):
            raise TypeError(
                f"{cls_type.__name__}() takes {len(positional_names)} positional "
                f"argument(s) but {len(args)} were given"
            )
        return dict(zip(positional_names, args))

    def __call__(cls, *args, **kwargs):
        """Instantiate ``cls`` passing every argument by keyword."""
        named_positional_args = NodeMeta.positional_to_kwargs(cls, args)
        # Unpacking both mappings makes Python raise TypeError when the same
        # parameter is supplied both positionally and by keyword.
        return super().__call__(**named_positional_args, **kwargs)
24
+
25
+
26
+ class _Node:
27
+ def __init__(
28
+ self,
29
+ is_router: bool = False,
30
+ ):
31
+ self.is_router = is_router
32
+
33
+ def _node(
34
+ self,
35
+ func: Callable,
36
+ start: bool = False,
37
+ end: bool = False,
38
+ stream: bool = False,
39
+ name: str | None = None,
40
+ ) -> Callable:
41
+ """Decorator validating node structure."""
42
+ if not callable(func):
43
+ raise ValueError("Node must be a callable function.")
44
+
45
+ func_signature = inspect.signature(func)
46
+ schema: FunctionSchema = FunctionSchema.from_callable(func)
47
+
48
+ class NodeClass:
49
+ _func_signature = func_signature
50
+ is_router: bool = Field(
51
+ default=False, description="Whether the node is a router."
52
+ )
53
+ # following attributes will be overridden by the decorator
54
+ name: str | None = Field(default=None, description="The name of the node.")
55
+ is_start: bool = Field(
56
+ default=False, description="Whether the node is the start of the graph."
57
+ )
58
+ is_end: bool = Field(
59
+ default=False, description="Whether the node is the end of the graph."
60
+ )
61
+ schema: FunctionSchema | None = Field(
62
+ default=None, description="The schema of the node."
63
+ )
64
+ stream: bool = Field(
65
+ default=False, description="Whether the node includes streaming object."
66
+ )
67
+
68
+ def __init__(self):
69
+ self._expected_params = set(self._func_signature.parameters.keys())
70
+
71
+ async def execute(self, *args, **kwargs):
72
+ # Prepare arguments, including callback if stream is True
73
+ params_dict = await self._parse_params(*args, **kwargs)
74
+ return await func(**params_dict) # Pass only the necessary arguments
75
+
76
+ async def _parse_params(self, *args, **kwargs) -> Dict[str, Any]:
77
+ # filter out unexpected keyword args
78
+ expected_kwargs = {
79
+ k: v for k, v in kwargs.items() if k in self._expected_params
80
+ }
81
+ # Convert args to kwargs based on the function signature
82
+ args_names = list(self._func_signature.parameters.keys())[
83
+ 1 : len(args) + 1
84
+ ] # skip 'self'
85
+ expected_args_kwargs = dict(zip(args_names, args))
86
+ # Combine filtered args and kwargs
87
+ combined_params = {**expected_args_kwargs, **expected_kwargs}
88
+
89
+ # Bind the current instance attributes to the function signature
90
+ if "callback" in self._expected_params and not stream:
91
+ raise ValueError(
92
+ f"Node {func.__name__}: requires stream=True when callback is defined."
93
+ )
94
+ bound_params = self._func_signature.bind_partial(**combined_params)
95
+ # get the default parameters (if any)
96
+ bound_params.apply_defaults()
97
+ params_dict = bound_params.arguments.copy()
98
+ # Filter arguments to match the next node's parameters
99
+ filtered_params = {
100
+ k: v for k, v in params_dict.items() if k in self._expected_params
101
+ }
102
+ # confirm all required parameters are present
103
+ missing_params = [
104
+ p for p in self._expected_params if p not in filtered_params
105
+ ]
106
+ # if anything is missing we raise an error
107
+ if missing_params:
108
+ raise ValueError(
109
+ f"Missing required parameters for the {func.__name__} node: {', '.join(missing_params)}"
110
+ )
111
+ return filtered_params
112
+
113
+ @classmethod
114
+ def get_signature(cls):
115
+ """Returns the signature of the decorated function as LLM readable
116
+ string.
117
+ """
118
+ signature_components = []
119
+ if NodeClass._func_signature:
120
+ for param in NodeClass._func_signature.parameters.values():
121
+ if param.default is param.empty:
122
+ signature_components.append(
123
+ f"{param.name}: {param.annotation}"
124
+ )
125
+ else:
126
+ signature_components.append(
127
+ f"{param.name}: {param.annotation} = {param.default}"
128
+ )
129
+ else:
130
+ return "No signature"
131
+ return "\n".join(signature_components)
132
+
133
+ @classmethod
134
+ async def invoke(
135
+ cls,
136
+ input: Dict[str, Any],
137
+ callback: Optional[Callback] = None,
138
+ state: Optional[Dict[str, Any]] = None,
139
+ ):
140
+ if callback:
141
+ if stream:
142
+ input["callback"] = callback
143
+ else:
144
+ raise ValueError(
145
+ f"Error in node {func.__name__}. When callback provided, stream must be True."
146
+ )
147
+ # Add state to the input if present and the parameter exists in the function signature
148
+ if state is not None and "state" in cls._func_signature.parameters:
149
+ input["state"] = state
150
+
151
+ instance = cls()
152
+ out = await instance.execute(**input)
153
+ return out
154
+
155
+ NodeClass.__name__ = func.__name__
156
+ node_class_name = name or func.__name__
157
+ if node_class_name is None:
158
+ raise ValueError("Unexpected error: node name not set.")
159
+ NodeClass.name = node_class_name
160
+ NodeClass.__doc__ = func.__doc__
161
+ NodeClass.is_start = start
162
+ NodeClass.is_end = end
163
+ NodeClass.is_router = self.is_router
164
+ NodeClass.stream = stream
165
+ NodeClass.schema = schema
166
+ return NodeClass
167
+
168
+ def __call__(
169
+ self,
170
+ func: Optional[Callable] = None,
171
+ start: bool = False,
172
+ end: bool = False,
173
+ stream: bool = False,
174
+ name: str | None = None,
175
+ ):
176
+ # We must wrap the call to the decorator in a function for it to work
177
+ # correctly with or without parenthesis
178
+ def wrap(
179
+ func: Callable, start=start, end=end, stream=stream, name=name
180
+ ) -> Callable:
181
+ return self._node(func=func, start=start, end=end, stream=stream, name=name)
182
+
183
+ if func:
184
+ # Decorator is called without parenthesis
185
+ return wrap(func=func, start=start, end=end, stream=stream, name=name)
186
+ # Decorator is called with parenthesis
187
+ return wrap
188
+
189
+
190
+ node = _Node()
191
+ router = _Node(is_router=True)
@@ -0,0 +1,205 @@
1
+ import inspect
2
+ from typing import Any, Callable, List, Optional
3
+ from pydantic import BaseModel, Field
4
+ import logging
5
+
6
+ import colorlog
7
+
8
+
9
class CustomFormatter(colorlog.ColoredFormatter):
    """Colour-coded log formatter used by the package logger.

    Each record is rendered as ``<date> <time> <LEVEL> <logger name> <message>``,
    coloured by level: DEBUG cyan, INFO green, WARNING yellow, ERROR red and
    CRITICAL bold red.
    """

    def __init__(self):
        palette = dict(
            DEBUG="cyan",
            INFO="green",
            WARNING="yellow",
            ERROR="red",
            CRITICAL="bold_red",
        )
        line_format = "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s"
        super().__init__(
            line_format,
            datefmt="%Y-%m-%d %H:%M:%S",
            log_colors=palette,
            reset=True,
            style="%",
        )
26
+
27
+
28
def add_coloured_handler(logger):
    """Attach a ``StreamHandler`` using ``CustomFormatter`` to ``logger``.

    Args:
        logger: The ``logging.Logger`` to decorate.

    Returns:
        The same logger, for chaining.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(CustomFormatter())
    logger.addHandler(handler)
    return logger
35
+
36
+
37
def setup_custom_logger(name):
    """Return the named logger, configured with a coloured console handler.

    On first setup the logger receives a ``CustomFormatter`` stream handler,
    level INFO, and ``propagate = False`` so ancestor handlers do not print
    the same records again. Later calls for the same name leave it untouched.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)

    # Inspect only this logger's own handlers. ``hasHandlers()`` also looks at
    # ancestors, so when the root logger is already configured (for example
    # via ``logging.basicConfig``) no handler would be added here while
    # propagation is still switched off -- silently dropping INFO output.
    if not logger.handlers:
        add_coloured_handler(logger)
        logger.setLevel(logging.INFO)
        logger.propagate = False

    return logger


logger: logging.Logger = setup_custom_logger(__name__)
50
+
51
+
52
+ def openai_type_mapping(param_type: str) -> str:
53
+ if param_type == "int":
54
+ return "number"
55
+ elif param_type == "float":
56
+ return "number"
57
+ elif param_type == "str":
58
+ return "string"
59
+ elif param_type == "bool":
60
+ return "boolean"
61
+ else:
62
+ return "object"
63
+
64
+
65
class Parameter(BaseModel):
    """Parameter for a function.

    :param name: The name of the parameter.
    :type name: str
    :param description: The description of the parameter.
    :type description: Optional[str]
    :param type: The type of the parameter.
    :type type: str
    :param default: The default value of the parameter.
    :type default: Any
    :param required: Whether the parameter is required.
    :type required: bool
    """

    name: str = Field(description="The name of the parameter")
    description: Optional[str] = Field(
        default=None, description="The description of the parameter"
    )
    type: str = Field(description="The type of the parameter")
    default: Any = Field(description="The default value of the parameter")
    required: bool = Field(description="Whether the parameter is required")

    def to_dict(self) -> dict[str, Any]:
        """Convert the parameter to a dictionary for an standard dictionary-based function schema.
        This is the most common format used by LLM providers, including OpenAI, Ollama, and others.

        :return: The parameter in dictionary format, keyed by parameter name.
        :rtype: dict[str, Any]
        """
        # JSON Schema requires ``description`` to be a string; passing None
        # through would serialise to ``null``, so fall back to a placeholder.
        description = (
            self.description if isinstance(self.description, str) else "None provided"
        )
        return {
            self.name: {
                "description": description,
                "type": openai_type_mapping(self.type),
            }
        }
101
+
102
+
103
def _annotation_name(annotation: Any) -> str:
    """Return a readable name for a type annotation.

    Classes yield their ``__name__`` (``int`` -> ``"int"``); string
    annotations (e.g. under ``from __future__ import annotations``) are
    returned unchanged; annotations without a ``__name__`` such as
    ``str | None`` fall back to ``str(annotation)`` instead of raising.
    """
    if isinstance(annotation, str):
        return annotation
    name = getattr(annotation, "__name__", None)
    return name if isinstance(name, str) else str(annotation)


class FunctionSchema(BaseModel):
    """Class that consumes a function and can return a schema required by
    different LLMs for function calling.
    """

    name: str = Field(description="The name of the function")
    description: str = Field(description="The description of the function")
    signature: str = Field(description="The signature of the function")
    output: str = Field(description="The output of the function")
    parameters: list[Parameter] = Field(description="The parameters of the function")

    @classmethod
    def from_callable(cls, function: Callable) -> "FunctionSchema":
        """Build a FunctionSchema by introspecting a callable.

        :param function: The function to consume.
        :type function: Callable
        :return: The schema; ``description`` is ``""`` when there is no docstring.
        :rtype: FunctionSchema
        :raises NotImplementedError: If given a non-callable pydantic model.
        :raises TypeError: If given anything else that is not callable.
        """
        if callable(function):
            # Callable objects / partials may lack ``__name__``.
            name = getattr(function, "__name__", type(function).__name__)
            # ``inspect.getdoc`` returns None when there is no docstring; test
            # it before converting to str, otherwise it becomes "None" and the
            # warning below can never fire.
            doc = inspect.getdoc(function)
            if not doc:
                logger.warning(f"Function {name} has no docstring")
            description = doc or ""
            func_signature = inspect.signature(function)
            parameters = [
                Parameter(
                    name=param.name,
                    type=_annotation_name(param.annotation),
                    default=param.default,
                    required=param.default is inspect.Parameter.empty,
                )
                for param in func_signature.parameters.values()
            ]
            return cls.model_construct(
                name=name,
                description=description,
                signature=str(func_signature),
                output=str(func_signature.return_annotation),
                parameters=parameters,
            )
        elif isinstance(function, BaseModel):
            raise NotImplementedError("Pydantic BaseModel not implemented yet.")
        else:
            raise TypeError("Function must be a Callable or BaseModel")

    @classmethod
    def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
        """Build a signature-only FunctionSchema from a pydantic model.

        Accepts either a model instance or the model class. ``parameters`` is
        left empty and ``output`` is not implemented yet.

        :param model: The pydantic model (instance or class) to describe.
        :return: The schema with a ``(field: type = default, ...) -> str`` signature.
        :rtype: FunctionSchema
        """
        # Read field metadata from the class: instance access to
        # ``model_fields`` is deprecated in pydantic 2.11.
        model_cls = model if isinstance(model, type) else type(model)
        fields = model_cls.model_fields
        signature_parts = []
        for field_name, field_model in model_cls.__annotations__.items():
            field_info = fields.get(field_name)
            if field_info is None:
                # e.g. ClassVar annotations, which are not model fields.
                continue
            type_name = _annotation_name(field_model)
            # Required fields hold a PydanticUndefined sentinel as ``default``
            # and factory-backed fields have no static default to show.
            if field_info.is_required() or field_info.default_factory is not None:
                signature_part = f"{field_name}: {type_name}"
            else:
                signature_part = f"{field_name}: {type_name} = {field_info.default!r}"
            signature_parts.append(signature_part)
        signature = f"({', '.join(signature_parts)}) -> str"
        return cls.model_construct(
            name=model_cls.__name__,
            description=model_cls.__doc__ or "",
            signature=signature,
            output="",  # TODO: Implement output
            parameters=[],
        )

    def to_dict(self) -> dict:
        """Return the standard dictionary-based (OpenAI-style) function schema."""
        schema_dict = {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        k: v for param in self.parameters for k, v in param.to_dict().items()
                    },
                    "required": [
                        param.name for param in self.parameters if param.required
                    ],
                },
            },
        }
        return schema_dict

    def to_openai(self) -> dict:
        """Alias of ``to_dict`` for the OpenAI function-calling format."""
        return self.to_dict()
194
+
195
+
196
# Format names that all share the standard dictionary-based function schema.
DEFAULT = {"default", "openai", "ollama", "litellm"}


def get_schemas(callables: List[Callable], format: str = "default") -> list[dict]:
    """Build function-calling schemas for a list of callables.

    Args:
        callables: Functions to describe.
        format: Target schema format; every name in ``DEFAULT`` yields the
            same dictionary-based schema.

    Returns:
        One schema dict per callable, in input order.

    Raises:
        ValueError: If ``format`` is not a member of ``DEFAULT``.
    """
    if format not in DEFAULT:
        raise ValueError(f"Format {format} not supported")
    schemas: list[dict] = []
    for fn in callables:
        schemas.append(FunctionSchema.from_callable(fn).to_dict())
    return schemas
@@ -1,12 +1,13 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
7
- Requires-Dist: semantic-router>=0.1.5
8
7
  Requires-Dist: networkx>=3.4.2
9
8
  Requires-Dist: matplotlib>=3.10.0
9
+ Requires-Dist: pydantic>=2.11.1
10
+ Requires-Dist: colorlog>=6.9.0
10
11
  Provides-Extra: dev
11
12
  Requires-Dist: ipykernel>=6.25.0; extra == "dev"
12
13
  Requires-Dist: ruff>=0.1.5; extra == "dev"
@@ -16,7 +17,9 @@ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
16
17
  Requires-Dist: pytest-xdist>=3.5.0; extra == "dev"
17
18
  Requires-Dist: pytest-asyncio>=0.24.0; extra == "dev"
18
19
  Requires-Dist: mypy>=1.7.1; extra == "dev"
19
- Requires-Dist: black[jupyter]<24.5.0,>=23.12.1; extra == "dev"
20
+ Requires-Dist: types-networkx>=3.4.2.20250319; extra == "dev"
21
+ Provides-Extra: docs
22
+ Requires-Dist: pydoc-markdown>=4.8.2; python_version < "3.12" and extra == "docs"
20
23
 
21
24
  # Philosophy
22
25
 
@@ -4,6 +4,8 @@ graphai/__init__.py
4
4
  graphai/callback.py
5
5
  graphai/graph.py
6
6
  graphai/utils.py
7
+ graphai/nodes/__init__.py
8
+ graphai/nodes/base.py
7
9
  graphai_lib.egg-info/PKG-INFO
8
10
  graphai_lib.egg-info/SOURCES.txt
9
11
  graphai_lib.egg-info/dependency_links.txt
@@ -1,6 +1,7 @@
1
- semantic-router>=0.1.5
2
1
  networkx>=3.4.2
3
2
  matplotlib>=3.10.0
3
+ pydantic>=2.11.1
4
+ colorlog>=6.9.0
4
5
 
5
6
  [dev]
6
7
  ipykernel>=6.25.0
@@ -11,4 +12,9 @@ pytest-cov>=4.1.0
11
12
  pytest-xdist>=3.5.0
12
13
  pytest-asyncio>=0.24.0
13
14
  mypy>=1.7.1
14
- black[jupyter]<24.5.0,>=23.12.1
15
+ types-networkx>=3.4.2.20250319
16
+
17
+ [docs]
18
+
19
+ [docs:python_version < "3.12"]
20
+ pydoc-markdown>=4.8.2
@@ -1,13 +1,14 @@
1
1
  [project]
2
2
  name = "graphai-lib"
3
- version = "0.0.3"
3
+ version = "0.0.5"
4
4
  description = "Not an AI framework"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.14"
7
7
  dependencies = [
8
- "semantic-router>=0.1.5",
9
8
  "networkx>=3.4.2",
10
9
  "matplotlib>=3.10.0",
10
+ "pydantic>=2.11.1",
11
+ "colorlog>=6.9.0",
11
12
  ]
12
13
 
13
14
  [project.optional-dependencies]
@@ -20,12 +21,13 @@ dev = [
20
21
  "pytest-xdist>=3.5.0",
21
22
  "pytest-asyncio>=0.24.0",
22
23
  "mypy>=1.7.1",
23
- "black[jupyter]>=23.12.1,<24.5.0",
24
+ "types-networkx>=3.4.2.20250319",
24
25
  ]
26
+ docs = ["pydoc-markdown>=4.8.2 ; python_version < '3.12'"]
25
27
 
26
28
  [build-system]
27
29
  requires = ["setuptools>=61.0"]
28
30
  build-backend = "setuptools.build_meta"
29
31
 
30
- [tool.setuptools.packages.find]
31
- include = ["graphai"]
32
+ [tool.setuptools]
33
+ packages = ["graphai", "graphai.nodes"]
@@ -1,125 +0,0 @@
1
- import inspect
2
- from typing import Any, Callable, Dict, List, Union, Optional
3
- from pydantic import BaseModel, Field
4
-
5
-
6
- class Parameter(BaseModel):
7
- class Config:
8
- arbitrary_types_allowed = True
9
-
10
- name: str = Field(description="The name of the parameter")
11
- description: Optional[str] = Field(
12
- default=None, description="The description of the parameter"
13
- )
14
- type: str = Field(description="The type of the parameter")
15
- default: Any = Field(description="The default value of the parameter")
16
- required: bool = Field(description="Whether the parameter is required")
17
-
18
- def to_openai(self):
19
- return {
20
- self.name: {
21
- "description": self.description,
22
- "type": self.type,
23
- }
24
- }
25
-
26
- class FunctionSchema:
27
- """Class that consumes a function and can return a schema required by
28
- different LLMs for function calling.
29
- """
30
-
31
- name: str = Field(description="The name of the function")
32
- description: str = Field(description="The description of the function")
33
- signature: str = Field(description="The signature of the function")
34
- output: str = Field(description="The output of the function")
35
- parameters: List[Parameter] = Field(description="The parameters of the function")
36
-
37
- def __init__(self, function: Union[Callable, BaseModel]):
38
- self.function = function
39
- if callable(function):
40
- self._process_function(function)
41
- elif isinstance(function, BaseModel):
42
- raise NotImplementedError("Pydantic BaseModel not implemented yet.")
43
- else:
44
- raise TypeError("Function must be a Callable or BaseModel")
45
-
46
- def _process_function(self, function: Callable):
47
- self.name = function.__name__
48
- self.description = str(inspect.getdoc(function))
49
- self.signature = str(inspect.signature(function))
50
- self.output = str(inspect.signature(function).return_annotation)
51
- parameters = []
52
- for param in inspect.signature(function).parameters.values():
53
- parameters.append(
54
- Parameter(
55
- name=param.name,
56
- type=param.annotation.__name__,
57
- default=param.default,
58
- required=param.default is inspect.Parameter.empty,
59
- )
60
- )
61
- self.parameters = parameters
62
-
63
- def to_openai(self):
64
- schema_dict = {
65
- "type": "function",
66
- "function": {
67
- "name": self.name,
68
- "description": self.description,
69
- "parameters": {
70
- "type": "object",
71
- "properties": {
72
- param.name: {
73
- "description": (
74
- param.description
75
- if isinstance(param.description, str)
76
- else "None provided"
77
- ),
78
- "type": self._openai_type_mapping(param.type),
79
- }
80
- for param in self.parameters
81
- },
82
- "required": [
83
- param.name for param in self.parameters if param.required
84
- ],
85
- },
86
- },
87
- }
88
- return schema_dict
89
-
90
- def _openai_type_mapping(self, param_type: str) -> str:
91
- if param_type == "int":
92
- return "number"
93
- elif param_type == "float":
94
- return "number"
95
- elif param_type == "str":
96
- return "string"
97
- elif param_type == "bool":
98
- return "boolean"
99
- else:
100
- return "object"
101
-
102
-
103
- def get_schema_pydantic(model: BaseModel) -> Dict[str, Any]:
104
- signature_parts = []
105
- for field_name, field_model in model.__annotations__.items():
106
- field_info = model.__fields__[field_name]
107
- default_value = field_info.default
108
-
109
- if default_value:
110
- default_repr = repr(default_value)
111
- signature_part = (
112
- f"{field_name}: {field_model.__name__} = {default_repr}"
113
- )
114
- else:
115
- signature_part = f"{field_name}: {field_model.__name__}"
116
-
117
- signature_parts.append(signature_part)
118
- signature = f"({', '.join(signature_parts)}) -> str"
119
- schema = FunctionSchema(
120
- name=model.__class__.__name__,
121
- description=model.__doc__,
122
- signature=signature,
123
- output="", # TODO: Implement output
124
- )
125
- return schema
File without changes
File without changes