graphai-lib 0.0.8__tar.gz → 0.0.9rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.8
3
+ Version: 0.0.9rc2
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,8 +1,22 @@
1
- from typing import Any, Protocol, Type
1
+ from __future__ import annotations
2
+ from typing import Any, Protocol
2
3
  from graphai.callback import Callback
3
4
  from graphai.utils import logger
4
5
 
5
6
 
7
+ # to fix mypy error
8
+ class _HasName(Protocol):
9
+ name: str
10
+
11
+
12
+ class GraphError(Exception):
13
+ pass
14
+
15
+
16
+ class GraphCompileError(GraphError):
17
+ pass
18
+
19
+
6
20
  class NodeProtocol(Protocol):
7
21
  """Protocol defining the interface of a decorated node."""
8
22
 
@@ -20,6 +34,26 @@ class NodeProtocol(Protocol):
20
34
  ) -> dict[str, Any]: ...
21
35
 
22
36
 
37
+ def _name_of(x: Any) -> str | None:
38
+ """Return the node name if x is a str or has .name, else None."""
39
+ if x is None:
40
+ return None
41
+ if isinstance(x, str):
42
+ return x
43
+ name = getattr(x, "name", None)
44
+ return name if isinstance(name, str) else None
45
+
46
+
47
+ def _require_name(x: Any, kind: str) -> str:
48
+ """Like _name_of, but raises a helpful compile error when missing."""
49
+ s = _name_of(x)
50
+ if s is None:
51
+ raise GraphCompileError(
52
+ f"Edge {kind} must be a node name (str) or object with .name"
53
+ )
54
+ return s
55
+
56
+
23
57
  class Graph:
24
58
  def __init__(
25
59
  self, max_steps: int = 10, initial_state: dict[str, Any] | None = None
@@ -28,7 +62,7 @@ class Graph:
28
62
  self.edges: list[Any] = []
29
63
  self.start_node: NodeProtocol | None = None
30
64
  self.end_nodes: list[NodeProtocol] = []
31
- self.Callback: Type[Callback] = Callback
65
+ self.Callback: type[Callback] = Callback
32
66
  self.max_steps = max_steps
33
67
  self.state = initial_state or {}
34
68
 
@@ -37,22 +71,22 @@ class Graph:
37
71
  """Get the current graph state."""
38
72
  return self.state
39
73
 
40
- def set_state(self, state: dict[str, Any]) -> "Graph":
74
+ def set_state(self, state: dict[str, Any]) -> Graph:
41
75
  """Set the graph state."""
42
76
  self.state = state
43
77
  return self
44
78
 
45
- def update_state(self, values: dict[str, Any]) -> "Graph":
79
+ def update_state(self, values: dict[str, Any]) -> Graph:
46
80
  """Update the graph state with new values."""
47
81
  self.state.update(values)
48
82
  return self
49
83
 
50
- def reset_state(self) -> "Graph":
84
+ def reset_state(self) -> Graph:
51
85
  """Reset the graph state to an empty dict."""
52
86
  self.state = {}
53
87
  return self
54
88
 
55
- def add_node(self, node: NodeProtocol) -> "Graph":
89
+ def add_node(self, node: NodeProtocol) -> Graph:
56
90
  if node.name in self.nodes:
57
91
  raise Exception(f"Node with name '{node.name}' already exists.")
58
92
  self.nodes[node.name] = node
@@ -68,7 +102,9 @@ class Graph:
68
102
  self.end_nodes.append(node)
69
103
  return self
70
104
 
71
- def add_edge(self, source: NodeProtocol | str, destination: NodeProtocol | str) -> "Graph":
105
+ def add_edge(
106
+ self, source: NodeProtocol | str, destination: NodeProtocol | str
107
+ ) -> Graph:
72
108
  """Adds an edge between two nodes that already exist in the graph.
73
109
 
74
110
  Args:
@@ -89,9 +125,7 @@ class Graph:
89
125
  else:
90
126
  source_name = str(source)
91
127
  if source_node is None:
92
- raise ValueError(
93
- f"Node with name '{source_name}' not found."
94
- )
128
+ raise ValueError(f"Node with name '{source_name}' not found.")
95
129
  # get destination node from graph
96
130
  destination_name: str
97
131
  if isinstance(destination, str):
@@ -105,9 +139,7 @@ class Graph:
105
139
  else:
106
140
  destination_name = str(destination)
107
141
  if destination_node is None:
108
- raise ValueError(
109
- f"Node with name '{destination_name}' not found."
110
- )
142
+ raise ValueError(f"Node with name '{destination_name}' not found.")
111
143
  edge = Edge(source_node, destination_node)
112
144
  self.edges.append(edge)
113
145
  return self
@@ -117,7 +149,7 @@ class Graph:
117
149
  sources: list[NodeProtocol],
118
150
  router: NodeProtocol,
119
151
  destinations: list[NodeProtocol],
120
- ) -> "Graph":
152
+ ) -> Graph:
121
153
  if not router.is_router:
122
154
  raise TypeError("A router object must be passed to the router parameter.")
123
155
  [self.add_edge(source, router) for source in sources]
@@ -125,26 +157,151 @@ class Graph:
125
157
  self.add_edge(router, destination)
126
158
  return self
127
159
 
128
- def set_start_node(self, node: NodeProtocol) -> "Graph":
160
+ def set_start_node(self, node: NodeProtocol) -> Graph:
129
161
  self.start_node = node
130
162
  return self
131
163
 
132
- def set_end_node(self, node: NodeProtocol) -> "Graph":
164
+ def set_end_node(self, node: NodeProtocol) -> Graph:
133
165
  self.end_node = node
134
166
  return self
135
167
 
136
168
  def compile(self) -> "Graph":
137
- if not self.start_node:
138
- raise Exception("Start node not defined.")
139
- if not self.end_nodes:
140
- raise Exception("No end nodes defined.")
141
- if not self._is_valid():
142
- raise Exception("Graph is not valid.")
143
- return self
169
+ """
170
+ Validate the graph:
171
+ - exactly one start node present (or Graph.start_node set)
172
+ - at least one end node present
173
+ - all edges reference known nodes
174
+ - no cycles
175
+ - all nodes reachable from the start
176
+ Returns self on success; raises GraphCompileError otherwise.
177
+ """
178
+ # nodes map
179
+ nodes = getattr(self, "nodes", None)
180
+ if not isinstance(nodes, dict) or not nodes:
181
+ raise GraphCompileError("No nodes have been added to the graph")
182
+
183
+ start_name: str | None = None
184
+ # Bind and narrow the attribute for mypy
185
+ start_node: _HasName | None = getattr(self, "start_node", None)
186
+ if start_node is not None:
187
+ start_name = start_node.name
188
+ else:
189
+ starts = [
190
+ name
191
+ for name, n in nodes.items()
192
+ if getattr(n, "is_start", False) or getattr(n, "start", False)
193
+ ]
194
+ if len(starts) > 1:
195
+ raise GraphCompileError(f"Multiple start nodes defined: {starts}")
196
+ if len(starts) == 1:
197
+ start_name = starts[0]
198
+
199
+ if not start_name:
200
+ raise GraphCompileError("No start node defined")
201
+
202
+ # at least one end node
203
+ if not any(
204
+ getattr(n, "is_end", False) or getattr(n, "end", False)
205
+ for n in nodes.values()
206
+ ):
207
+ raise GraphCompileError("No end node defined")
208
+
209
+ # normalize edges into adjacency {src: set(dst)}
210
+ raw_edges = getattr(self, "edges", None)
211
+ adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
212
+
213
+ def _add_edge(src: str, dst: str) -> None:
214
+ if src not in nodes:
215
+ raise GraphCompileError(f"Edge references unknown source node: {src}")
216
+ if dst not in nodes:
217
+ raise GraphCompileError(
218
+ f"Edge from {src} references unknown node(s): ['{dst}']"
219
+ )
220
+ adj[src].add(dst)
221
+
222
+ if raw_edges is None:
223
+ pass
224
+ elif isinstance(raw_edges, dict):
225
+ for raw_src, dsts in raw_edges.items():
226
+ src = _require_name(raw_src, "source")
227
+ dst_iter = (
228
+ [dsts]
229
+ if isinstance(dsts, (str,)) or getattr(dsts, "name", None)
230
+ else list(dsts)
231
+ )
232
+ for d in dst_iter:
233
+ dst = _require_name(d, "destination")
234
+ _add_edge(src, dst)
235
+ else:
236
+ # generic iterable of “edge records”
237
+ try:
238
+ iterator = iter(raw_edges)
239
+ except TypeError:
240
+ raise GraphCompileError("Internal edge map has unsupported type")
241
+
242
+ for item in iterator:
243
+ # (src, dst) OR (src, Iterable[dst])
244
+ if isinstance(item, (tuple, list)) and len(item) == 2:
245
+ raw_src, rhs = item
246
+ src = _require_name(raw_src, "source")
247
+
248
+ if isinstance(rhs, str) or getattr(rhs, "name", None):
249
+ dst = _require_name(rhs, "destination")
250
+ _add_edge(src, rhs)
251
+ else:
252
+ # assume iterable of dsts (strings or node-like)
253
+ try:
254
+ for d in rhs:
255
+ dst = _require_name(d, "destination")
256
+ _add_edge(src, d)
257
+ except TypeError:
258
+ raise GraphCompileError(
259
+ "Edge tuple second item must be a destination or an iterable of destinations"
260
+ )
261
+ continue
262
+
263
+ # Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
264
+ if isinstance(item, dict):
265
+ src = _require_name(item.get("source", item.get("src")), "source")
266
+ dst = _require_name(
267
+ item.get("destination", item.get("dst")), "destination"
268
+ )
269
+ _add_edge(src, dst)
270
+ continue
271
+
272
+ # Object with attributes .source/.destination (or .src/.dst)
273
+ if hasattr(item, "source") or hasattr(item, "src"):
274
+ src = _require_name(
275
+ getattr(item, "source", getattr(item, "src", None)), "source"
276
+ )
277
+ dst = _require_name(
278
+ getattr(item, "destination", getattr(item, "dst", None)),
279
+ "destination",
280
+ )
281
+ _add_edge(src, dst)
282
+ continue
283
+
284
+ # If none matched, this is an unsupported edge record
285
+ raise GraphCompileError(
286
+ "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
287
+ "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
288
+ )
289
+
290
+ # reachability from start
291
+ seen: set[str] = set()
292
+ stack = [start_name]
293
+ while stack:
294
+ cur = stack.pop()
295
+ if cur in seen:
296
+ continue
297
+ seen.add(cur)
298
+ stack.extend(adj.get(cur, ()))
144
299
 
145
- def _is_valid(self):
146
- # Implement validation logic, e.g., checking for cycles, disconnected components, etc.
147
- return True
300
+ unreachable = sorted(set(nodes.keys()) - seen)
301
+ if unreachable:
302
+ raise GraphCompileError(f"Unreachable nodes: {unreachable}")
303
+
304
+ return self
148
305
 
149
306
  def _validate_output(self, output: dict[str, Any], node_name: str):
150
307
  if not isinstance(output, dict):
@@ -219,7 +376,7 @@ class Graph:
219
376
  as the default callback when no callback is passed to the `execute` method.
220
377
 
221
378
  :param callback_class: The callback class to use as the default callback.
222
- :type callback_class: Type[Callback]
379
+ :type callback_class: type[Callback]
223
380
  """
224
381
  self.Callback = callback_class
225
382
  return self
@@ -249,20 +406,24 @@ class Graph:
249
406
  f"No outgoing edge found for current node '{current_node.name}'."
250
407
  )
251
408
 
252
- def visualize(self):
409
+ def visualize(self, *, save_path: str | None = None):
410
+ """Render the current graph. If matplotlib is not installed,
411
+ raise a helpful error telling users to install the viz extra.
412
+ Optionally save to a file via `save_path`.
413
+ """
253
414
  try:
254
- import networkx as nx
255
- except ImportError:
415
+ import matplotlib.pyplot as plt
416
+ except ImportError as e:
256
417
  raise ImportError(
257
- "NetworkX is required for visualization. Please install it with 'pip install networkx'."
258
- )
418
+ "Graph visualization requires matplotlib. Install it with: `pip install matplotlib`"
419
+ ) from e
259
420
 
260
421
  try:
261
- import matplotlib.pyplot as plt
262
- except ImportError:
422
+ import networkx as nx
423
+ except ImportError as e:
263
424
  raise ImportError(
264
- "Matplotlib is required for visualization. Please install it with 'pip install matplotlib'."
265
- )
425
+ "NetworkX is required for visualization. Please install it with `pip install networkx`."
426
+ ) from e
266
427
 
267
428
  G: Any = nx.DiGraph()
268
429
 
@@ -328,8 +489,12 @@ class Graph:
328
489
  arrowsize=20,
329
490
  )
330
491
 
331
- plt.axis("off")
332
- plt.show()
492
+ if save_path:
493
+ plt.savefig(save_path, bbox_inches="tight")
494
+ else:
495
+ plt.axis("off")
496
+ plt.show()
497
+ plt.close()
333
498
 
334
499
 
335
500
  class Edge:
File without changes
@@ -1,9 +1,18 @@
1
+ from enum import Enum
1
2
  import inspect
2
3
  import os
3
- from typing import Any, Callable
4
+ import sys
5
+ from typing import Any, Callable, Union, get_args, get_origin
4
6
  from pydantic import BaseModel, Field
7
+ from pydantic_core import PydanticUndefined
5
8
  import logging
6
- import sys
9
+
10
+
11
+ # we support python 3.10 so we define our own StrEnum (introduced in 3.11)
12
+ class StrEnum(str, Enum):
13
+ """Backport of StrEnum for Python < 3.11"""
14
+ def __str__(self):
15
+ return self.value
7
16
 
8
17
 
9
18
  class ColoredFormatter(logging.Formatter):
@@ -119,6 +128,9 @@ class Parameter(BaseModel):
119
128
  }
120
129
  }
121
130
 
131
+ class OpenAIAPI(StrEnum):
132
+ COMPLETIONS = "completions"
133
+ RESPONSES = "responses"
122
134
 
123
135
  class FunctionSchema(BaseModel):
124
136
  """Class that consumes a function and can return a schema required by
@@ -167,29 +179,95 @@ class FunctionSchema(BaseModel):
167
179
  )
168
180
 
169
181
  @classmethod
170
- def from_pydantic(cls, model: BaseModel) -> "FunctionSchema":
182
+ def from_pydantic(cls, model: type[BaseModel]) -> "FunctionSchema":
183
+ """Create a FunctionSchema from a Pydantic model class.
184
+
185
+ :param model: The Pydantic model class to convert
186
+ :type model: type[BaseModel]
187
+ :return: FunctionSchema instance
188
+ :rtype: FunctionSchema
189
+ """
190
+ # Extract model metadata
191
+ name = model.__name__
192
+ description = model.__doc__ or ""
193
+
194
+ # Build parameters list
195
+ parameters = []
171
196
  signature_parts = []
172
- for field_name, field_model in model.__annotations__.items():
173
- field_info = model.model_fields[field_name]
174
- default_value = field_info.default
175
- if default_value:
176
- default_repr = repr(default_value)
177
- signature_part = (
178
- f"{field_name}: {field_model.__name__} = {default_repr}"
197
+
198
+ for field_name, field_info in model.model_fields.items():
199
+ # Get the field type
200
+ field_type = model.__annotations__.get(field_name)
201
+
202
+ # Determine the type name - handle Optional and other generic types
203
+ type_name = str(field_type)
204
+
205
+ # Try to extract the actual type from Optional[T] -> T
206
+ origin = get_origin(field_type)
207
+ args = get_args(field_type)
208
+
209
+ if origin is Union:
210
+ # This is likely Optional[T] which is Union[T, None]
211
+ non_none_types = [arg for arg in args if arg is not type(None)]
212
+ if non_none_types:
213
+ actual_type = non_none_types[0]
214
+ if hasattr(actual_type, '__name__'):
215
+ type_name = actual_type.__name__
216
+ else:
217
+ type_name = str(actual_type)
218
+ elif field_type and hasattr(field_type, '__name__'):
219
+ type_name = field_type.__name__
220
+
221
+ # Check if field is required (no default value)
222
+ # In Pydantic v2, PydanticUndefined means no default
223
+ is_required = (
224
+ field_info.default is PydanticUndefined
225
+ and field_info.default_factory is None
226
+ )
227
+
228
+ # Get the actual default value
229
+ if field_info.default is not PydanticUndefined and field_info.default is not None:
230
+ default_value = field_info.default
231
+ elif field_info.default_factory is not None:
232
+ # For default_factory, we can't always call it without arguments
233
+ # Just use a placeholder to indicate there's a factory
234
+ try:
235
+ # Try calling with no arguments (common case)
236
+ default_value = field_info.default_factory() # type: ignore[call-arg]
237
+ except TypeError:
238
+ # If it needs arguments, just indicate it has a factory default
239
+ default_value = "<factory>"
240
+ else:
241
+ default_value = inspect.Parameter.empty
242
+
243
+ # Add parameter
244
+ parameters.append(
245
+ Parameter(
246
+ name=field_name,
247
+ description=field_info.description,
248
+ type=type_name,
249
+ default=default_value,
250
+ required=is_required,
179
251
  )
252
+ )
253
+
254
+ # Build signature part
255
+ if default_value != inspect.Parameter.empty:
256
+ signature_parts.append(f"{field_name}: {type_name} = {repr(default_value)}")
180
257
  else:
181
- signature_part = f"{field_name}: {field_model.__name__}"
182
- signature_parts.append(signature_part)
183
- signature = f"({', '.join(signature_parts)}) -> str"
258
+ signature_parts.append(f"{field_name}: {type_name}")
259
+
260
+ signature = f"({', '.join(signature_parts)}) -> dict"
261
+
184
262
  return cls.model_construct(
185
- name=model.__class__.__name__,
186
- description=model.__doc__ or "",
263
+ name=name,
264
+ description=description,
187
265
  signature=signature,
188
- output="", # TODO: Implement output
189
- parameters=[],
266
+ output="dict",
267
+ parameters=parameters,
190
268
  )
191
269
 
192
- def to_dict(self) -> dict:
270
+ def to_dict(self) -> dict[str, Any]:
193
271
  schema_dict = {
194
272
  "type": "function",
195
273
  "function": {
@@ -210,14 +288,42 @@ class FunctionSchema(BaseModel):
210
288
  }
211
289
  return schema_dict
212
290
 
213
- def to_openai(self) -> dict:
214
- return self.to_dict()
291
+ def to_openai(self, api: OpenAIAPI=OpenAIAPI.COMPLETIONS) -> dict[str, Any]:
292
+ """Convert the function schema into OpenAI-compatible formats. Supports
293
+ both completions and responses APIs.
294
+
295
+ :param api: The API to convert to.
296
+ :type api: OpenAIAPI
297
+ :return: The function schema in OpenAI-compatible format.
298
+ :rtype: dict
299
+ """
300
+ if api == "completions":
301
+ return self.to_dict()
302
+ elif api == "responses":
303
+ return {
304
+ "type": "function",
305
+ "name": self.name,
306
+ "description": self.description,
307
+ "parameters": {
308
+ "type": "object",
309
+ "properties": {
310
+ k: v
311
+ for param in self.parameters
312
+ for k, v in param.to_dict().items()
313
+ },
314
+ "required": [
315
+ param.name for param in self.parameters if param.required
316
+ ],
317
+ },
318
+ }
319
+ else:
320
+ raise ValueError(f"Unrecognized OpenAI API: {api}")
215
321
 
216
322
 
217
323
  DEFAULT = set(["default", "openai", "ollama", "litellm"])
218
324
 
219
325
 
220
- def get_schemas(callables: list[Callable], format: str = "default") -> list[dict]:
326
+ def get_schemas(callables: list[Callable], format: str = "default") -> list[dict[str, Any]]:
221
327
  if format in DEFAULT:
222
328
  return [
223
329
  FunctionSchema.from_callable(callable).to_dict() for callable in callables
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.8
3
+ Version: 0.0.9rc2
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -3,6 +3,7 @@ pyproject.toml
3
3
  graphai/__init__.py
4
4
  graphai/callback.py
5
5
  graphai/graph.py
6
+ graphai/py.typed
6
7
  graphai/utils.py
7
8
  graphai/nodes/__init__.py
8
9
  graphai/nodes/base.py
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "graphai-lib"
3
- version = "0.0.8"
3
+ version = "0.0.9rc2"
4
4
  description = "Not an AI framework"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.14"
@@ -29,6 +29,9 @@ build-backend = "setuptools.build_meta"
29
29
  [tool.setuptools]
30
30
  packages = ["graphai", "graphai.nodes"]
31
31
 
32
+ [tool.setuptools.package-data]
33
+ graphai = ["py.typed"]
34
+
32
35
  [tool.mypy]
33
36
  python_version = "3.10"
34
37
  warn_return_any = true
File without changes
File without changes