graphai-lib 0.0.9rc3__tar.gz → 0.0.10rc2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.9rc3
3
+ Version: 0.0.10rc2
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,7 @@
1
1
  import asyncio
2
2
  from dataclasses import dataclass
3
3
  from enum import Enum
4
+ import json
4
5
  from pydantic import Field
5
6
  from typing import Any
6
7
  from collections.abc import AsyncIterator
@@ -43,6 +44,22 @@ class GraphEvent:
43
44
  token: str | None = None
44
45
  params: dict[str, Any] | None = None
45
46
 
47
+ def encode(self, charset: str = "utf-8") -> bytes:
48
+ """Encodes the event as a JSON string, important for compatibility with FastAPI
49
+ and starlette.
50
+
51
+ :param charset: The character set to use for encoding the event.
52
+ :type charset: str
53
+ """
54
+ event_dict = {
55
+ "type": self.type.value if hasattr(self.type, "value") else str(self.type),
56
+ "identifier": self.identifier,
57
+ "token": self.token,
58
+ "params": self.params,
59
+ }
60
+ data = f"data: {json.dumps(event_dict, ensure_ascii=False, separators=(',', ':'))}\n\n"
61
+ return data.encode(charset)
62
+
46
63
 
47
64
  class Callback:
48
65
  """The original callback handler class. Outputs a stream of structured text
@@ -1,5 +1,8 @@
1
1
  from __future__ import annotations
2
- from typing import Any, Protocol
2
+ import asyncio
3
+ from typing import Any, Iterable, Protocol
4
+ from graphlib import TopologicalSorter, CycleError
5
+
3
6
  from graphai.callback import Callback
4
7
  from graphai.utils import logger
5
8
 
@@ -62,22 +65,41 @@ class Graph:
62
65
  self.edges: list[Any] = []
63
66
  self.start_node: NodeProtocol | None = None
64
67
  self.end_nodes: list[NodeProtocol] = []
68
+ self.join_nodes: set[NodeProtocol] = set()
65
69
  self.Callback: type[Callback] = Callback
66
70
  self.max_steps = max_steps
67
71
  self.state = initial_state or {}
68
72
 
69
73
  # Allow getting and setting the graph's internal state
70
74
  def get_state(self) -> dict[str, Any]:
71
- """Get the current graph state."""
75
+ """Get the current graph state.
76
+
77
+ Returns:
78
+ The current graph state.
79
+ """
72
80
  return self.state
73
81
 
74
82
  def set_state(self, state: dict[str, Any]) -> Graph:
75
- """Set the graph state."""
83
+ """Set the graph state.
84
+
85
+ Args:
86
+ state: The new state to set for the graph.
87
+
88
+ Returns:
89
+ The graph instance.
90
+ """
76
91
  self.state = state
77
92
  return self
78
93
 
79
94
  def update_state(self, values: dict[str, Any]) -> Graph:
80
- """Update the graph state with new values."""
95
+ """Update the graph state with new values.
96
+
97
+ Args:
98
+ values: The new values to update the graph state with.
99
+
100
+ Returns:
101
+ The graph instance.
102
+ """
81
103
  self.state.update(values)
82
104
  return self
83
105
 
@@ -87,6 +109,14 @@ class Graph:
87
109
  return self
88
110
 
89
111
  def add_node(self, node: NodeProtocol) -> Graph:
112
+ """Adds a node to the graph.
113
+
114
+ Args:
115
+ node: The node to add to the graph.
116
+
117
+ Raises:
118
+ Exception: If a node with the same name already exists in the graph.
119
+ """
90
120
  if node.name in self.nodes:
91
121
  raise Exception(f"Node with name '{node.name}' already exists.")
92
122
  self.nodes[node.name] = node
@@ -102,6 +132,18 @@ class Graph:
102
132
  self.end_nodes.append(node)
103
133
  return self
104
134
 
135
+ def _get_node(self, node_candidate: NodeProtocol | str) -> NodeProtocol:
136
+ # first get node from graph
137
+ if isinstance(node_candidate, str):
138
+ node = self.nodes.get(node_candidate)
139
+ else:
140
+ # check if it's a node-like object by looking for required attributes
141
+ if hasattr(node_candidate, "name"):
142
+ node = self.nodes.get(node_candidate.name)
143
+ if node is None:
144
+ raise ValueError(f"Node with name '{node_candidate}' not found.")
145
+ return node
146
+
105
147
  def add_edge(
106
148
  self, source: NodeProtocol | str, destination: NodeProtocol | str
107
149
  ) -> Graph:
@@ -113,33 +155,10 @@ class Graph:
113
155
  """
114
156
  source_node, destination_node = None, None
115
157
  # get source node from graph
116
- source_name: str
117
- if isinstance(source, str):
118
- source_node = self.nodes.get(source)
119
- source_name = source
120
- else:
121
- # Check if it's a node-like object by looking for required attributes
122
- if hasattr(source, "name"):
123
- source_node = self.nodes.get(source.name)
124
- source_name = source.name
125
- else:
126
- source_name = str(source)
127
- if source_node is None:
128
- raise ValueError(f"Node with name '{source_name}' not found.")
158
+ source_node = self._get_node(node_candidate=source)
129
159
  # get destination node from graph
130
- destination_name: str
131
- if isinstance(destination, str):
132
- destination_node = self.nodes.get(destination)
133
- destination_name = destination
134
- else:
135
- # Check if it's a node-like object by looking for required attributes
136
- if hasattr(destination, "name"):
137
- destination_node = self.nodes.get(destination.name)
138
- destination_name = destination.name
139
- else:
140
- destination_name = str(destination)
141
- if destination_node is None:
142
- raise ValueError(f"Node with name '{destination_name}' not found.")
160
+ destination_node = self._get_node(node_candidate=destination)
161
+ # create edge
143
162
  edge = Edge(source_node, destination_node)
144
163
  self.edges.append(edge)
145
164
  return self
@@ -150,6 +169,14 @@ class Graph:
150
169
  router: NodeProtocol,
151
170
  destinations: list[NodeProtocol],
152
171
  ) -> Graph:
172
+ """Adds a router node, allowing for a decision to be made on which branch to
173
+ follow based on the `choice` output of the router node.
174
+
175
+ Args:
176
+ sources: The list of source nodes for the router.
177
+ router: The router node.
178
+ destinations: The list of destination nodes for the router.
179
+ """
153
180
  if not router.is_router:
154
181
  raise TypeError("A router object must be passed to the router parameter.")
155
182
  [self.add_edge(source, router) for source in sources]
@@ -165,21 +192,19 @@ class Graph:
165
192
  self.end_node = node
166
193
  return self
167
194
 
168
- def compile(self) -> "Graph":
169
- """
170
- Validate the graph:
195
+ def compile(self, *, strict: bool = False) -> Graph:
196
+ """Validate the graph:
171
197
  - exactly one start node present (or Graph.start_node set)
172
198
  - at least one end node present
173
199
  - all edges reference known nodes
174
- - no cycles
175
200
  - all nodes reachable from the start
201
+ (optional) **no cycles** when strict=True
176
202
  Returns self on success; raises GraphCompileError otherwise.
177
203
  """
178
204
  # nodes map
179
205
  nodes = getattr(self, "nodes", None)
180
206
  if not isinstance(nodes, dict) or not nodes:
181
207
  raise GraphCompileError("No nodes have been added to the graph")
182
-
183
208
  start_name: str | None = None
184
209
  # Bind and narrow the attribute for mypy
185
210
  start_node: _HasName | None = getattr(self, "start_node", None)
@@ -195,21 +220,17 @@ class Graph:
195
220
  raise GraphCompileError(f"Multiple start nodes defined: {starts}")
196
221
  if len(starts) == 1:
197
222
  start_name = starts[0]
198
-
199
223
  if not start_name:
200
224
  raise GraphCompileError("No start node defined")
201
-
202
225
  # at least one end node
203
226
  if not any(
204
227
  getattr(n, "is_end", False) or getattr(n, "end", False)
205
228
  for n in nodes.values()
206
229
  ):
207
230
  raise GraphCompileError("No end node defined")
208
-
209
231
  # normalize edges into adjacency {src: set(dst)}
210
232
  raw_edges = getattr(self, "edges", None)
211
233
  adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
212
-
213
234
  def _add_edge(src: str, dst: str) -> None:
214
235
  if src not in nodes:
215
236
  raise GraphCompileError(f"Edge references unknown source node: {src}")
@@ -218,7 +239,6 @@ class Graph:
218
239
  f"Edge from {src} references unknown node(s): ['{dst}']"
219
240
  )
220
241
  adj[src].add(dst)
221
-
222
242
  if raw_edges is None:
223
243
  pass
224
244
  elif isinstance(raw_edges, dict):
@@ -238,13 +258,11 @@ class Graph:
238
258
  iterator = iter(raw_edges)
239
259
  except TypeError:
240
260
  raise GraphCompileError("Internal edge map has unsupported type")
241
-
242
261
  for item in iterator:
243
262
  # (src, dst) OR (src, Iterable[dst])
244
263
  if isinstance(item, (tuple, list)) and len(item) == 2:
245
264
  raw_src, rhs = item
246
265
  src = _require_name(raw_src, "source")
247
-
248
266
  if isinstance(rhs, str) or getattr(rhs, "name", None):
249
267
  dst = _require_name(rhs, "destination")
250
268
  _add_edge(src, rhs)
@@ -259,7 +277,6 @@ class Graph:
259
277
  "Edge tuple second item must be a destination or an iterable of destinations"
260
278
  )
261
279
  continue
262
-
263
280
  # Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
264
281
  if isinstance(item, dict):
265
282
  src = _require_name(item.get("source", item.get("src")), "source")
@@ -268,7 +285,6 @@ class Graph:
268
285
  )
269
286
  _add_edge(src, dst)
270
287
  continue
271
-
272
288
  # Object with attributes .source/.destination (or .src/.dst)
273
289
  if hasattr(item, "source") or hasattr(item, "src"):
274
290
  src = _require_name(
@@ -280,13 +296,11 @@ class Graph:
280
296
  )
281
297
  _add_edge(src, dst)
282
298
  continue
283
-
284
299
  # If none matched, this is an unsupported edge record
285
300
  raise GraphCompileError(
286
301
  "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
287
302
  "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
288
303
  )
289
-
290
304
  # reachability from start
291
305
  seen: set[str] = set()
292
306
  stack = [start_name]
@@ -296,11 +310,19 @@ class Graph:
296
310
  continue
297
311
  seen.add(cur)
298
312
  stack.extend(adj.get(cur, ()))
299
-
300
313
  unreachable = sorted(set(nodes.keys()) - seen)
301
314
  if unreachable:
302
315
  raise GraphCompileError(f"Unreachable nodes: {unreachable}")
303
-
316
+ # optional cycle detection (strict mode)
317
+ if strict:
318
+ preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
319
+ for s, ds in adj.items():
320
+ for d in ds:
321
+ preds[d].add(s)
322
+ try:
323
+ list(TopologicalSorter(preds).static_order())
324
+ except CycleError as e:
325
+ raise GraphCompileError("cycle detected in graph (strict mode)") from e
304
326
  return self
305
327
 
306
328
  def _validate_output(self, output: dict[str, Any], node_name: str):
@@ -310,57 +332,149 @@ class Graph:
310
332
  f"Instead, got {type(output)} from '{output}'."
311
333
  )
312
334
 
313
- async def execute(self, input, callback: Callback | None = None):
314
- # TODO JB: may need to add init callback here to init the queue on every new execution
315
- if callback is None:
316
- callback = self.get_callback()
317
-
318
- # Type assertion to tell the type checker that start_node is not None after compile()
319
- assert self.start_node is not None, "Graph must be compiled before execution"
320
- current_node = self.start_node
335
+ def _get_next_nodes(self, current_node: NodeProtocol) -> list[NodeProtocol]:
336
+ """Return all successor nodes for the given node."""
337
+ # we skip JoinEdge because they don't have regular destinations
338
+ # and next nodes for those are handled in the execute method
339
+ return [
340
+ edge.destination
341
+ for edge in self.edges
342
+ if isinstance(edge, Edge) and edge.source == current_node
343
+ ]
344
+
345
+ async def _invoke_node(
346
+ self, node: NodeProtocol, state: dict[str, Any], callback: Callback
347
+ ):
348
+ if node.stream:
349
+ await callback.start_node(node_name=node.name)
350
+ output = await node.invoke(input=state, callback=callback, state=self.state)
351
+ self._validate_output(output=output, node_name=node.name)
352
+ await callback.end_node(node_name=node.name)
353
+ else:
354
+ output = await node.invoke(input=state, state=self.state)
355
+ self._validate_output(output=output, node_name=node.name)
356
+ return output
321
357
 
322
- state = input
323
- # Don't reset the graph state if it was initialized with initial_state
324
- steps = 0
358
+ async def _execute_branch(
359
+ self,
360
+ current_node: NodeProtocol,
361
+ state: dict[str, Any],
362
+ callback: Callback,
363
+ steps: int,
364
+ stop_at_join: bool = False,
365
+ ):
366
+ """Recursively execute a branch starting from `current_node`.
367
+ When a node has multiple successors, run them concurrently and merge their outputs."""
325
368
  while True:
326
- # we invoke the node here
327
- if current_node.stream:
328
- # add callback tokens and param here if we are streaming
329
- await callback.start_node(node_name=current_node.name)
330
- # Include graph's internal state in the node execution context
331
- output = await current_node.invoke(
332
- input=state, callback=callback, state=self.state
333
- )
334
- self._validate_output(output=output, node_name=current_node.name)
335
- await callback.end_node(node_name=current_node.name)
336
- else:
337
- # Include graph's internal state in the node execution context
338
- output = await current_node.invoke(input=state, state=self.state)
339
- self._validate_output(output=output, node_name=current_node.name)
340
- # add output to state
341
- state = {**state, **output}
369
+ output = await self._invoke_node(current_node, state, callback)
370
+ state = {**state, **output} # merge node output into local state
342
371
  if current_node.is_end:
343
- # finish loop if this was an end node
344
372
  break
345
373
  if current_node.is_router:
346
- # if we have a router node we let the router decide the next node
347
374
  next_node_name = str(output["choice"])
348
375
  del output["choice"]
349
376
  current_node = self._get_node_by_name(node_name=next_node_name)
377
+ continue
378
+ if stop_at_join and current_node in self.join_nodes:
379
+ # for parallel branches, wait at JoinEdge until all branches are complete
380
+ return state
381
+
382
+ next_nodes = self._get_next_nodes(current_node)
383
+ if not next_nodes:
384
+ raise Exception(
385
+ f"No outgoing edge found for current node '{current_node.name}'."
386
+ )
387
+ if len(next_nodes) == 1:
388
+ current_node = next_nodes[0]
350
389
  else:
351
- # otherwise, we have linear path
352
- current_node = self._get_next_node(current_node=current_node)
390
+ # Run each branch concurrently
391
+ results = await asyncio.gather(
392
+ *[
393
+ self._execute_branch(
394
+ current_node=n,
395
+ state=state.copy(),
396
+ callback=callback,
397
+ steps=steps + 1,
398
+ stop_at_join=True, # force parallel branches to wait at JoinEdge
399
+ )
400
+ for n in next_nodes
401
+ ]
402
+ )
403
+ # merge states returned by each branch
404
+ merged = state.copy()
405
+ for res in results:
406
+ for k, v in res.items():
407
+ if k != "callback":
408
+ merged[k] = v
409
+ if set(next_nodes) & self.join_nodes:
410
+ # if any of the next nodes are join nodes, we need to continue from the
411
+ # JoinEdge.destination node
412
+ join_edge = next(
413
+ (
414
+ e for e in self.edges if isinstance(e, JoinEdge)
415
+ and any(n in e.sources for n in next_nodes)
416
+ ),
417
+ None
418
+ )
419
+ if not join_edge:
420
+ raise Exception("No JoinEdge found for next_nodes")
421
+ # set current_node (for next iteration) to the JoinEdge.destination
422
+ current_node = join_edge.destination
423
+ # continue to the destination node with our merged state
424
+ state = merged
425
+ continue
426
+ else:
427
+ # if this happens we have multiple branches that do not join so we
428
+ # can just return the merged states
429
+ return merged
353
430
  steps += 1
354
431
  if steps >= self.max_steps:
355
432
  raise Exception(
356
- f"Max steps reached: {self.max_steps}. You can modify this "
357
- "by setting `max_steps` when initializing the Graph object."
433
+ f"Max steps reached: {self.max_steps}. You can modify this by setting `max_steps` when initializing the Graph object."
358
434
  )
435
+ return state
436
+
437
+ async def execute(self, input: dict[str, Any], callback: Callback | None = None):
438
+ # TODO JB: may need to add init callback here to init the queue on every new execution
439
+ if callback is None:
440
+ callback = self.get_callback()
441
+
442
+ # Type assertion to tell the type checker that start_node is not None after compile()
443
+ assert self.start_node is not None, "Graph must be compiled before execution"
444
+
445
+ state = input
446
+ result = await self._execute_branch(self.start_node, state, callback, 0)
359
447
  # TODO JB: may need to add end callback here to close the queue for every execution
360
- if callback and "callback" in state:
448
+ if callback and "callback" in result:
361
449
  await callback.close()
362
- del state["callback"]
363
- return state
450
+ del result["callback"]
451
+ return result
452
+
453
+ async def execute_many(
454
+ self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
455
+ ) -> list[Any]:
456
+ """Execute the graph on many inputs concurrently.
457
+
458
+ :param inputs: An iterable of input dicts to feed into the graph.
459
+ :type inputs: Iterable[dict]
460
+ :param concurrency: Maximum number of graph executions to run at once.
461
+ :type concurrency: int
462
+ :param state: Optional shared state to pass to each execution.
463
+ If you want isolated state per execution, pass None
464
+ and the graph's normal semantics will apply.
465
+ :type state: Optional[Any]
466
+ :returns: The list of results in the same order as ``inputs``.
467
+ :rtype: list[Any]
468
+ """
469
+
470
+ sem = asyncio.Semaphore(concurrency)
471
+
472
+ async def _run_one(inp: dict[str, Any]) -> Any:
473
+ async with sem:
474
+ return await self.execute(input=inp)
475
+
476
+ tasks = [asyncio.create_task(_run_one(i)) for i in inputs]
477
+ return await asyncio.gather(*tasks)
364
478
 
365
479
  def get_callback(self):
366
480
  """Get a new instance of the callback class.
@@ -400,12 +514,46 @@ class Graph:
400
514
 
401
515
  def _get_next_node(self, current_node):
402
516
  for edge in self.edges:
403
- if edge.source == current_node:
517
+ if isinstance(edge, Edge) and edge.source == current_node:
404
518
  return edge.destination
519
+ # we skip JoinEdge because they don't have regular destinations
520
+ # and next nodes for those are handled in the execute method
405
521
  raise Exception(
406
522
  f"No outgoing edge found for current node '{current_node.name}'."
407
523
  )
408
524
 
525
+ def add_parallel(
526
+ self, source: NodeProtocol | str, destinations: list[NodeProtocol | str]
527
+ ):
528
+ """Add multiple outgoing edges from a single source node to be executed in parallel.
529
+
530
+ Args:
531
+ source: The source node for the parallel branches.
532
+ destinations: The list of destination nodes for the parallel branches.
533
+ """
534
+ for dest in destinations:
535
+ self.add_edge(source, dest)
536
+ return self
537
+
538
+ def add_join(
539
+ self, sources: list[NodeProtocol | str], destination: NodeProtocol | str
540
+ ):
541
+ """Joins multiple parallel branches into a single branch.
542
+
543
+ Args:
544
+ sources: The list of source nodes for the join.
545
+ destination: The destination node for the join.
546
+ """
547
+ # get source nodes from graph
548
+ source_nodes = [self._get_node(node_candidate=source) for source in sources]
549
+ # get destination node from graph
550
+ destination_node = self._get_node(node_candidate=destination)
551
+ # create join edge
552
+ edge = JoinEdge(source_nodes, destination_node)
553
+ self.edges.append(edge)
554
+ self.join_nodes.update(source_nodes)
555
+ return self
556
+
409
557
  def visualize(self, *, save_path: str | None = None):
410
558
  """Render the current graph. If matplotlib is not installed,
411
559
  raise a helpful error telling users to install the viz extra.
@@ -501,3 +649,8 @@ class Edge:
501
649
  def __init__(self, source, destination):
502
650
  self.source = source
503
651
  self.destination = destination
652
+
653
+ class JoinEdge:
654
+ def __init__(self, sources, destination):
655
+ self.sources = sources
656
+ self.destination = destination
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.9rc3
3
+ Version: 0.0.10rc2
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "graphai-lib"
3
- version = "0.0.9rc3"
3
+ version = "0.0.10rc2"
4
4
  description = "Not an AI framework"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10,<3.14"