graphai-lib 0.0.8__py3-none-any.whl → 0.0.9rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graphai/graph.py CHANGED
@@ -1,8 +1,23 @@
1
+ from __future__ import annotations
1
2
  from typing import Any, Protocol, Type
3
+ from graphlib import TopologicalSorter, CycleError
2
4
  from graphai.callback import Callback
3
5
  from graphai.utils import logger
4
6
 
5
7
 
8
+ # to fix mypy error
9
+ class _HasName(Protocol):
10
+ name: str
11
+
12
+
13
+ class GraphError(Exception):
14
+ pass
15
+
16
+
17
+ class GraphCompileError(GraphError):
18
+ pass
19
+
20
+
6
21
  class NodeProtocol(Protocol):
7
22
  """Protocol defining the interface of a decorated node."""
8
23
 
@@ -20,6 +35,26 @@ class NodeProtocol(Protocol):
20
35
  ) -> dict[str, Any]: ...
21
36
 
22
37
 
38
+ def _name_of(x: Any) -> str | None:
39
+ """Return the node name if x is a str or has .name, else None."""
40
+ if x is None:
41
+ return None
42
+ if isinstance(x, str):
43
+ return x
44
+ name = getattr(x, "name", None)
45
+ return name if isinstance(name, str) else None
46
+
47
+
48
def _require_name(x: Any, kind: str) -> str:
    """Resolve ``x`` to a node name, failing compilation if it cannot be.

    Uses the same rules as ``_name_of``: a ``str`` or an object with a
    ``str`` ``name``. ``kind`` (e.g. ``"source"``/``"destination"``) is
    interpolated into the error message.

    Raises:
        GraphCompileError: if ``x`` cannot be resolved to a name.
    """
    resolved = _name_of(x)
    if resolved is not None:
        return resolved
    raise GraphCompileError(
        f"Edge {kind} must be a node name (str) or object with .name"
    )
56
+
57
+
23
58
  class Graph:
24
59
  def __init__(
25
60
  self, max_steps: int = 10, initial_state: dict[str, Any] | None = None
@@ -37,22 +72,22 @@ class Graph:
37
72
  """Get the current graph state."""
38
73
  return self.state
39
74
 
40
def set_state(self, state: dict[str, Any]) -> Graph:
    """Replace the whole graph state with ``state``.

    The dict is stored by reference (not copied). Returns ``self`` so calls
    can be chained.
    """
    self.state = state  # stored as-is, no copy
    return self
44
79
 
45
def update_state(self, values: dict[str, Any]) -> Graph:
    """Merge ``values`` into the current state in place.

    Keys present in ``values`` overwrite existing ones; all other keys are
    kept. Returns ``self`` so calls can be chained.
    """
    current = self.state
    current.update(values)
    return self
49
84
 
50
def reset_state(self) -> Graph:
    """Discard the current state by binding a brand-new empty dict.

    The previous dict object is left untouched. Returns ``self`` so calls can
    be chained.
    """
    self.state = dict()
    return self
54
89
 
55
- def add_node(self, node: NodeProtocol) -> "Graph":
90
+ def add_node(self, node: NodeProtocol) -> Graph:
56
91
  if node.name in self.nodes:
57
92
  raise Exception(f"Node with name '{node.name}' already exists.")
58
93
  self.nodes[node.name] = node
@@ -68,7 +103,9 @@ class Graph:
68
103
  self.end_nodes.append(node)
69
104
  return self
70
105
 
71
- def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str) -> "Graph":
106
+ def add_edge(
107
+ self, source: NodeProtocol | str, destination: NodeProtocol | str
108
+ ) -> Graph:
72
109
  """Adds an edge between two nodes that already exist in the graph.
73
110
 
74
111
  Args:
@@ -89,9 +126,7 @@ class Graph:
89
126
  else:
90
127
  source_name = str(source)
91
128
  if source_node is None:
92
- raise ValueError(
93
- f"Node with name '{source_name}' not found."
94
- )
129
+ raise ValueError(f"Node with name '{source_name}' not found.")
95
130
  # get destination node from graph
96
131
  destination_name: str
97
132
  if isinstance(destination, str):
@@ -105,9 +140,7 @@ class Graph:
105
140
  else:
106
141
  destination_name = str(destination)
107
142
  if destination_node is None:
108
- raise ValueError(
109
- f"Node with name '{destination_name}' not found."
110
- )
143
+ raise ValueError(f"Node with name '{destination_name}' not found.")
111
144
  edge = Edge(source_node, destination_node)
112
145
  self.edges.append(edge)
113
146
  return self
@@ -117,7 +150,7 @@ class Graph:
117
150
  sources: list[NodeProtocol],
118
151
  router: NodeProtocol,
119
152
  destinations: list[NodeProtocol],
120
- ) -> "Graph":
153
+ ) -> Graph:
121
154
  if not router.is_router:
122
155
  raise TypeError("A router object must be passed to the router parameter.")
123
156
  [self.add_edge(source, router) for source in sources]
@@ -125,26 +158,162 @@ class Graph:
125
158
  self.add_edge(router, destination)
126
159
  return self
127
160
 
128
def set_start_node(self, node: NodeProtocol) -> Graph:
    """Record ``node`` as the graph's start node (``self.start_node``).

    The node is not added to the graph by this call. Returns ``self`` so
    calls can be chained.
    """
    self.start_node = node
    return self
131
164
 
132
def set_end_node(self, node: NodeProtocol) -> Graph:
    """Record ``node`` as the graph's end node (``self.end_node``).

    The node is not added to the graph by this call. Returns ``self`` so
    calls can be chained.
    """
    self.end_node = node
    return self
135
168
 
136
169
  def compile(self) -> "Graph":
137
- if not self.start_node:
138
- raise Exception("Start node not defined.")
139
- if not self.end_nodes:
140
- raise Exception("No end nodes defined.")
141
- if not self._is_valid():
142
- raise Exception("Graph is not valid.")
143
- return self
170
+ """
171
+ Validate the graph:
172
+ - exactly one start node present (or Graph.start_node set)
173
+ - at least one end node present
174
+ - all edges reference known nodes
175
+ - no cycles
176
+ - all nodes reachable from the start
177
+ Returns self on success; raises GraphCompileError otherwise.
178
+ """
179
+ # nodes map
180
+ nodes = getattr(self, "nodes", None)
181
+ if not isinstance(nodes, dict) or not nodes:
182
+ raise GraphCompileError("No nodes have been added to the graph")
183
+
184
+ start_name: str | None = None
185
+ # Bind and narrow the attribute for mypy
186
+ start_node: _HasName | None = getattr(self, "start_node", None)
187
+ if start_node is not None:
188
+ start_name = start_node.name
189
+ else:
190
+ starts = [
191
+ name
192
+ for name, n in nodes.items()
193
+ if getattr(n, "is_start", False) or getattr(n, "start", False)
194
+ ]
195
+ if len(starts) > 1:
196
+ raise GraphCompileError(f"Multiple start nodes defined: {starts}")
197
+ if len(starts) == 1:
198
+ start_name = starts[0]
199
+
200
+ if not start_name:
201
+ raise GraphCompileError("No start node defined")
202
+
203
+ # at least one end node
204
+ if not any(
205
+ getattr(n, "is_end", False) or getattr(n, "end", False)
206
+ for n in nodes.values()
207
+ ):
208
+ raise GraphCompileError("No end node defined")
209
+
210
+ # normalize edges into adjacency {src: set(dst)}
211
+ raw_edges = getattr(self, "edges", None)
212
+ adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
213
+
214
+ def _add_edge(src: str, dst: str) -> None:
215
+ if src not in nodes:
216
+ raise GraphCompileError(f"Edge references unknown source node: {src}")
217
+ if dst not in nodes:
218
+ raise GraphCompileError(
219
+ f"Edge from {src} references unknown node(s): ['{dst}']"
220
+ )
221
+ adj[src].add(dst)
222
+
223
+ if raw_edges is None:
224
+ pass
225
+ elif isinstance(raw_edges, dict):
226
+ for raw_src, dsts in raw_edges.items():
227
+ src = _require_name(raw_src, "source")
228
+ dst_iter = (
229
+ [dsts]
230
+ if isinstance(dsts, (str,)) or getattr(dsts, "name", None)
231
+ else list(dsts)
232
+ )
233
+ for d in dst_iter:
234
+ dst = _require_name(d, "destination")
235
+ _add_edge(src, dst)
236
+ else:
237
+ # generic iterable of “edge records”
238
+ try:
239
+ iterator = iter(raw_edges)
240
+ except TypeError:
241
+ raise GraphCompileError("Internal edge map has unsupported type")
242
+
243
+ for item in iterator:
244
+ # (src, dst) OR (src, Iterable[dst])
245
+ if isinstance(item, (tuple, list)) and len(item) == 2:
246
+ raw_src, rhs = item
247
+ src = _require_name(raw_src, "source")
248
+
249
+ if isinstance(rhs, str) or getattr(rhs, "name", None):
250
+ dst = _require_name(rhs, "destination")
251
+ _add_edge(src, rhs)
252
+ else:
253
+ # assume iterable of dsts (strings or node-like)
254
+ try:
255
+ for d in rhs:
256
+ dst = _require_name(d, "destination")
257
+ _add_edge(src, d)
258
+ except TypeError:
259
+ raise GraphCompileError(
260
+ "Edge tuple second item must be a destination or an iterable of destinations"
261
+ )
262
+ continue
263
+
264
+ # Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
265
+ if isinstance(item, dict):
266
+ src = _require_name(item.get("source", item.get("src")), "source")
267
+ dst = _require_name(
268
+ item.get("destination", item.get("dst")), "destination"
269
+ )
270
+ _add_edge(src, dst)
271
+ continue
272
+
273
+ # Object with attributes .source/.destination (or .src/.dst)
274
+ if hasattr(item, "source") or hasattr(item, "src"):
275
+ src = _require_name(
276
+ getattr(item, "source", getattr(item, "src", None)), "source"
277
+ )
278
+ dst = _require_name(
279
+ getattr(item, "destination", getattr(item, "dst", None)),
280
+ "destination",
281
+ )
282
+ _add_edge(src, dst)
283
+ continue
284
+
285
+ # If none matched, this is an unsupported edge record
286
+ raise GraphCompileError(
287
+ "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
288
+ "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
289
+ )
290
+
291
+ # cycle detection
292
+ preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
293
+ for s, ds in adj.items():
294
+ for d in ds:
295
+ preds[d].add(s)
296
+
297
+ try:
298
+ list(TopologicalSorter(preds).static_order())
299
+ except CycleError as e:
300
+ raise GraphCompileError("Cycle detected in graph") from e
301
+
302
+ # reachability from start
303
+ seen: set[str] = set()
304
+ stack = [start_name]
305
+ while stack:
306
+ cur = stack.pop()
307
+ if cur in seen:
308
+ continue
309
+ seen.add(cur)
310
+ stack.extend(adj.get(cur, ()))
311
+
312
+ unreachable = sorted(set(nodes.keys()) - seen)
313
+ if unreachable:
314
+ raise GraphCompileError(f"Unreachable nodes: {unreachable}")
144
315
 
145
- def _is_valid(self):
146
- # Implement validation logic, e.g., checking for cycles, disconnected components, etc.
147
- return True
316
+ return self
148
317
 
149
318
  def _validate_output(self, output: dict[str, Any], node_name: str):
150
319
  if not isinstance(output, dict):
@@ -249,20 +418,24 @@ class Graph:
249
418
  f"No outgoing edge found for current node '{current_node.name}'."
250
419
  )
251
420
 
252
- def visualize(self):
421
+ def visualize(self, *, save_path: str | None = None):
422
+ """Render the current graph. If matplotlib is not installed,
423
+ raise a helpful error telling users to install the viz extra.
424
+ Optionally save to a file via `save_path`.
425
+ """
253
426
  try:
254
- import networkx as nx
255
- except ImportError:
427
+ import matplotlib.pyplot as plt
428
+ except ImportError as e:
256
429
  raise ImportError(
257
- "NetworkX is required for visualization. Please install it with 'pip install networkx'."
258
- )
430
+ "Graph visualization requires matplotlib. Install it with: `pip install matplotlib`"
431
+ ) from e
259
432
 
260
433
  try:
261
- import matplotlib.pyplot as plt
262
- except ImportError:
434
+ import networkx as nx
435
+ except ImportError as e:
263
436
  raise ImportError(
264
- "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
265
- )
437
+ "NetworkX is required for visualization. Please install it with `pip install networkx`."
438
+ ) from e
266
439
 
267
440
  G: Any = nx.DiGraph()
268
441
 
@@ -328,8 +501,12 @@ class Graph:
328
501
  arrowsize=20,
329
502
  )
330
503
 
331
- plt.axis("off")
332
- plt.show()
504
+ if save_path:
505
+ plt.savefig(save_path, bbox_inches="tight")
506
+ else:
507
+ plt.axis("off")
508
+ plt.show()
509
+ plt.close()
333
510
 
334
511
 
335
512
  class Edge:
graphai/py.typed ADDED
File without changes
graphai/utils.py CHANGED
@@ -1,9 +1,18 @@
1
+ from enum import Enum
1
2
  import inspect
2
3
  import os
3
- from typing import Any, Callable
4
+ import sys
5
+ from typing import Any, Callable, Union, get_args, get_origin
4
6
  from pydantic import BaseModel, Field
7
+ from pydantic_core import PydanticUndefined
5
8
  import logging
6
- import sys
9
+
10
+
11
# Python 3.10 is supported, so we ship our own StrEnum (the stdlib one was
# added in 3.11).
class StrEnum(str, Enum):
    """Backport of :class:`enum.StrEnum` for Python < 3.11.

    Members are ``str`` instances and ``str(member)`` returns the raw value.
    As in the stdlib version, ``auto()`` produces the lower-cased member name.
    """

    @staticmethod
    def _generate_next_value_(
        name: str, start: int, count: int, last_values: list[Any]
    ) -> str:
        # Mirror stdlib StrEnum: auto() -> member name in lower case
        # (plain Enum would produce an integer, stringified to "1", "2", ...).
        return name.lower()

    def __str__(self) -> str:
        return self.value
7
16
 
8
17
 
9
18
  class ColoredFormatter(logging.Formatter):
@@ -119,6 +128,9 @@ class Parameter(BaseModel):
119
128
  }
120
129
  }
121
130
 
131
class OpenAIAPI(StrEnum):
    """OpenAI API flavours accepted by ``FunctionSchema.to_openai``."""

    COMPLETIONS = "completions"
    RESPONSES = "responses"
122
134
 
123
135
  class FunctionSchema(BaseModel):
124
136
  """Class that consumes a function and can return a schema required by
@@ -167,29 +179,95 @@ class FunctionSchema(BaseModel):
167
179
  )
168
180
 
169
181
  @classmethod
170
- def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
182
+ def from_pydantic(cls, model: type[BaseModel]) -> "FunctionSchema":
183
+ """Create a FunctionSchema from a Pydantic model class.
184
+
185
+ :param model: The Pydantic model class to convert
186
+ :type model: type[BaseModel]
187
+ :return: FunctionSchema instance
188
+ :rtype: FunctionSchema
189
+ """
190
+ # Extract model metadata
191
+ name = model.__name__
192
+ description = model.__doc__ or ""
193
+
194
+ # Build parameters list
195
+ parameters = []
171
196
  signature_parts = []
172
- for field_name, field_model in model.__annotations__.items():
173
- field_info = model.model_fields[field_name]
174
- default_value = field_info.default
175
- if default_value:
176
- default_repr = repr(default_value)
177
- signature_part = (
178
- f"{field_name}: {field_model.__name__} = {default_repr}"
197
+
198
+ for field_name, field_info in model.model_fields.items():
199
+ # Get the field type
200
+ field_type = model.__annotations__.get(field_name)
201
+
202
+ # Determine the type name - handle Optional and other generic types
203
+ type_name = str(field_type)
204
+
205
+ # Try to extract the actual type from Optional[T] -> T
206
+ origin = get_origin(field_type)
207
+ args = get_args(field_type)
208
+
209
+ if origin is Union:
210
+ # This is likely Optional[T] which is Union[T, None]
211
+ non_none_types = [arg for arg in args if arg is not type(None)]
212
+ if non_none_types:
213
+ actual_type = non_none_types[0]
214
+ if hasattr(actual_type, '__name__'):
215
+ type_name = actual_type.__name__
216
+ else:
217
+ type_name = str(actual_type)
218
+ elif field_type and hasattr(field_type, '__name__'):
219
+ type_name = field_type.__name__
220
+
221
+ # Check if field is required (no default value)
222
+ # In Pydantic v2, PydanticUndefined means no default
223
+ is_required = (
224
+ field_info.default is PydanticUndefined
225
+ and field_info.default_factory is None
226
+ )
227
+
228
+ # Get the actual default value
229
+ if field_info.default is not PydanticUndefined and field_info.default is not None:
230
+ default_value = field_info.default
231
+ elif field_info.default_factory is not None:
232
+ # For default_factory, we can't always call it without arguments
233
+ # Just use a placeholder to indicate there's a factory
234
+ try:
235
+ # Try calling with no arguments (common case)
236
+ default_value = field_info.default_factory() # type: ignore[call-arg]
237
+ except TypeError:
238
+ # If it needs arguments, just indicate it has a factory default
239
+ default_value = "<factory>"
240
+ else:
241
+ default_value = inspect.Parameter.empty
242
+
243
+ # Add parameter
244
+ parameters.append(
245
+ Parameter(
246
+ name=field_name,
247
+ description=field_info.description,
248
+ type=type_name,
249
+ default=default_value,
250
+ required=is_required,
179
251
  )
252
+ )
253
+
254
+ # Build signature part
255
+ if default_value != inspect.Parameter.empty:
256
+ signature_parts.append(f"{field_name}: {type_name} = {repr(default_value)}")
180
257
  else:
181
- signature_part = f"{field_name}: {field_model.__name__}"
182
- signature_parts.append(signature_part)
183
- signature = f"({', '.join(signature_parts)}) -> str"
258
+ signature_parts.append(f"{field_name}: {type_name}")
259
+
260
+ signature = f"({', '.join(signature_parts)}) -> dict"
261
+
184
262
  return cls.model_construct(
185
- name=model.__class__.__name__,
186
- description=model.__doc__ or "",
263
+ name=name,
264
+ description=description,
187
265
  signature=signature,
188
- output="", # TODO: Implement output
189
- parameters=[],
266
+ output="dict",
267
+ parameters=parameters,
190
268
  )
191
269
 
192
- def to_dict(self) -> dict:
270
+ def to_dict(self) -> dict[str, Any]:
193
271
  schema_dict = {
194
272
  "type": "function",
195
273
  "function": {
@@ -210,14 +288,42 @@ class FunctionSchema(BaseModel):
210
288
  }
211
289
  return schema_dict
212
290
 
213
def to_openai(self, api: OpenAIAPI = OpenAIAPI.COMPLETIONS) -> dict[str, Any]:
    """Render this schema as an OpenAI tool definition.

    :param api: Target API. ``"completions"`` (the default) returns
        ``self.to_dict()``; ``"responses"`` returns the flat Responses-API
        tool shape built from ``self.parameters``.
    :type api: OpenAIAPI
    :raises ValueError: if ``api`` is neither ``"completions"`` nor
        ``"responses"``.
    :return: The function schema in OpenAI-compatible format.
    :rtype: dict
    """
    if api == "responses":
        # Later parameters win on key collisions, as with a dict merge.
        properties: dict[str, Any] = {}
        for param in self.parameters:
            properties.update(param.to_dict())
        required = [p.name for p in self.parameters if p.required]
        return {
            "type": "function",
            "name": self.name,
            "description": self.description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        }
    if api == "completions":
        return self.to_dict()
    raise ValueError(f"Unrecognized OpenAI API: {api}")
215
321
 
216
322
 
217
323
  DEFAULT = set(["default", "openai", "ollama", "litellm"])
218
324
 
219
325
 
220
- def get_schemas(callables: list[Callable], format: str = "default") -> list[dict]:
326
+ def get_schemas(callables: list[Callable], format: str = "default") -> list[dict[str, Any]]:
221
327
  if format in DEFAULT:
222
328
  return [
223
329
  FunctionSchema.from_callable(callable).to_dict() for callable in callables
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.8
3
+ Version: 0.0.9rc1
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -0,0 +1,11 @@
1
+ graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
2
+ graphai/callback.py,sha256=Wl0JCmE8NcFifKmP9-a5bFa0WKVHTdrSClHVRmIEPpc,7323
3
+ graphai/graph.py,sha256=Bm_Jwa5EMUACqXTZSoibVQfHSwS1rV-ExfxqYRIkPDY,18944
4
+ graphai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ graphai/utils.py,sha256=LIFg-fQalU9sB5DCuk6is48OdpEgNX95i9h-YddFbvM,11717
6
+ graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
7
+ graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
8
+ graphai_lib-0.0.9rc1.dist-info/METADATA,sha256=T8nS6wOKwd7ZfzVvD_h4m38bowb9zzr7TutdIz5GWJQ,913
9
+ graphai_lib-0.0.9rc1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ graphai_lib-0.0.9rc1.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
11
+ graphai_lib-0.0.9rc1.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
2
- graphai/callback.py,sha256=Wl0JCmE8NcFifKmP9-a5bFa0WKVHTdrSClHVRmIEPpc,7323
3
- graphai/graph.py,sha256=4pY07zSs8VakN9Pma3xgiFOgKqkrHMi5C04umo4UN8c,12342
4
- graphai/utils.py,sha256=_q3pE7rgeXzGLYE41LWn87rlW3t74hQQwqHI1PvVjrY,7554
5
- graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
6
- graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
7
- graphai_lib-0.0.8.dist-info/METADATA,sha256=5OmA9X6HwtZxOr2GexbLFsBv7TBL8bMzE5OpLZ5oReE,910
8
- graphai_lib-0.0.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
- graphai_lib-0.0.8.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
10
- graphai_lib-0.0.8.dist-info/RECORD,,