graphai-lib 0.0.10rc1__py3-none-any.whl → 0.0.10rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graphai/graph.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
  import asyncio
3
3
  from typing import Any, Iterable, Protocol
4
4
  from graphlib import TopologicalSorter, CycleError
5
+
5
6
  from graphai.callback import Callback
6
7
  from graphai.utils import logger
7
8
 
@@ -64,22 +65,41 @@ class Graph:
64
65
  self.edges: list[Any] = []
65
66
  self.start_node: NodeProtocol | None = None
66
67
  self.end_nodes: list[NodeProtocol] = []
68
+ self.join_nodes: set[NodeProtocol] = set()
67
69
  self.Callback: type[Callback] = Callback
68
70
  self.max_steps = max_steps
69
71
  self.state = initial_state or {}
70
72
 
71
73
  # Allow getting and setting the graph's internal state
72
74
  def get_state(self) -> dict[str, Any]:
73
- """Get the current graph state."""
75
+ """Get the current graph state.
76
+
77
+ Returns:
78
+ The current graph state.
79
+ """
74
80
  return self.state
75
81
 
76
82
  def set_state(self, state: dict[str, Any]) -> Graph:
77
- """Set the graph state."""
83
+ """Set the graph state.
84
+
85
+ Args:
86
+ state: The new state to set for the graph.
87
+
88
+ Returns:
89
+ The graph instance.
90
+ """
78
91
  self.state = state
79
92
  return self
80
93
 
81
94
  def update_state(self, values: dict[str, Any]) -> Graph:
82
- """Update the graph state with new values."""
95
+ """Update the graph state with new values.
96
+
97
+ Args:
98
+ values: The new values to update the graph state with.
99
+
100
+ Returns:
101
+ The graph instance.
102
+ """
83
103
  self.state.update(values)
84
104
  return self
85
105
 
@@ -89,6 +109,14 @@ class Graph:
89
109
  return self
90
110
 
91
111
  def add_node(self, node: NodeProtocol) -> Graph:
112
+ """Adds a node to the graph.
113
+
114
+ Args:
115
+ node: The node to add to the graph.
116
+
117
+ Raises:
118
+ Exception: If a node with the same name already exists in the graph.
119
+ """
92
120
  if node.name in self.nodes:
93
121
  raise Exception(f"Node with name '{node.name}' already exists.")
94
122
  self.nodes[node.name] = node
@@ -104,6 +132,18 @@ class Graph:
104
132
  self.end_nodes.append(node)
105
133
  return self
106
134
 
135
+ def _get_node(self, node_candidate: NodeProtocol | str) -> NodeProtocol:
136
+ # first get node from graph
137
+ if isinstance(node_candidate, str):
138
+ node = self.nodes.get(node_candidate)
139
+ else:
140
+ # check if it's a node-like object by looking for required attributes
141
+ if hasattr(node_candidate, "name"):
142
+ node = self.nodes.get(node_candidate.name)
143
+ if node is None:
144
+ raise ValueError(f"Node with name '{node_candidate}' not found.")
145
+ return node
146
+
107
147
  def add_edge(
108
148
  self, source: NodeProtocol | str, destination: NodeProtocol | str
109
149
  ) -> Graph:
@@ -115,33 +155,10 @@ class Graph:
115
155
  """
116
156
  source_node, destination_node = None, None
117
157
  # get source node from graph
118
- source_name: str
119
- if isinstance(source, str):
120
- source_node = self.nodes.get(source)
121
- source_name = source
122
- else:
123
- # Check if it's a node-like object by looking for required attributes
124
- if hasattr(source, "name"):
125
- source_node = self.nodes.get(source.name)
126
- source_name = source.name
127
- else:
128
- source_name = str(source)
129
- if source_node is None:
130
- raise ValueError(f"Node with name '{source_name}' not found.")
158
+ source_node = self._get_node(node_candidate=source)
131
159
  # get destination node from graph
132
- destination_name: str
133
- if isinstance(destination, str):
134
- destination_node = self.nodes.get(destination)
135
- destination_name = destination
136
- else:
137
- # Check if it's a node-like object by looking for required attributes
138
- if hasattr(destination, "name"):
139
- destination_node = self.nodes.get(destination.name)
140
- destination_name = destination.name
141
- else:
142
- destination_name = str(destination)
143
- if destination_node is None:
144
- raise ValueError(f"Node with name '{destination_name}' not found.")
160
+ destination_node = self._get_node(node_candidate=destination)
161
+ # create edge
145
162
  edge = Edge(source_node, destination_node)
146
163
  self.edges.append(edge)
147
164
  return self
@@ -152,6 +169,14 @@ class Graph:
152
169
  router: NodeProtocol,
153
170
  destinations: list[NodeProtocol],
154
171
  ) -> Graph:
172
+ """Adds a router node, allowing for a decision to be made on which branch to
173
+ follow based on the `choice` output of the router node.
174
+
175
+ Args:
176
+ sources: The list of source nodes for the router.
177
+ router: The router node.
178
+ destinations: The list of destination nodes for the router.
179
+ """
155
180
  if not router.is_router:
156
181
  raise TypeError("A router object must be passed to the router parameter.")
157
182
  [self.add_edge(source, router) for source in sources]
@@ -168,8 +193,7 @@ class Graph:
168
193
  return self
169
194
 
170
195
  def compile(self, *, strict: bool = False) -> Graph:
171
- """
172
- Validate the graph:
196
+ """Validate the graph:
173
197
  - exactly one start node present (or Graph.start_node set)
174
198
  - at least one end node present
175
199
  - all edges reference known nodes
@@ -181,7 +205,6 @@ class Graph:
181
205
  nodes = getattr(self, "nodes", None)
182
206
  if not isinstance(nodes, dict) or not nodes:
183
207
  raise GraphCompileError("No nodes have been added to the graph")
184
-
185
208
  start_name: str | None = None
186
209
  # Bind and narrow the attribute for mypy
187
210
  start_node: _HasName | None = getattr(self, "start_node", None)
@@ -197,21 +220,17 @@ class Graph:
197
220
  raise GraphCompileError(f"Multiple start nodes defined: {starts}")
198
221
  if len(starts) == 1:
199
222
  start_name = starts[0]
200
-
201
223
  if not start_name:
202
224
  raise GraphCompileError("No start node defined")
203
-
204
225
  # at least one end node
205
226
  if not any(
206
227
  getattr(n, "is_end", False) or getattr(n, "end", False)
207
228
  for n in nodes.values()
208
229
  ):
209
230
  raise GraphCompileError("No end node defined")
210
-
211
231
  # normalize edges into adjacency {src: set(dst)}
212
232
  raw_edges = getattr(self, "edges", None)
213
233
  adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
214
-
215
234
  def _add_edge(src: str, dst: str) -> None:
216
235
  if src not in nodes:
217
236
  raise GraphCompileError(f"Edge references unknown source node: {src}")
@@ -220,7 +239,6 @@ class Graph:
220
239
  f"Edge from {src} references unknown node(s): ['{dst}']"
221
240
  )
222
241
  adj[src].add(dst)
223
-
224
242
  if raw_edges is None:
225
243
  pass
226
244
  elif isinstance(raw_edges, dict):
@@ -240,13 +258,11 @@ class Graph:
240
258
  iterator = iter(raw_edges)
241
259
  except TypeError:
242
260
  raise GraphCompileError("Internal edge map has unsupported type")
243
-
244
261
  for item in iterator:
245
262
  # (src, dst) OR (src, Iterable[dst])
246
263
  if isinstance(item, (tuple, list)) and len(item) == 2:
247
264
  raw_src, rhs = item
248
265
  src = _require_name(raw_src, "source")
249
-
250
266
  if isinstance(rhs, str) or getattr(rhs, "name", None):
251
267
  dst = _require_name(rhs, "destination")
252
268
  _add_edge(src, rhs)
@@ -261,7 +277,6 @@ class Graph:
261
277
  "Edge tuple second item must be a destination or an iterable of destinations"
262
278
  )
263
279
  continue
264
-
265
280
  # Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
266
281
  if isinstance(item, dict):
267
282
  src = _require_name(item.get("source", item.get("src")), "source")
@@ -270,7 +285,6 @@ class Graph:
270
285
  )
271
286
  _add_edge(src, dst)
272
287
  continue
273
-
274
288
  # Object with attributes .source/.destination (or .src/.dst)
275
289
  if hasattr(item, "source") or hasattr(item, "src"):
276
290
  src = _require_name(
@@ -282,13 +296,11 @@ class Graph:
282
296
  )
283
297
  _add_edge(src, dst)
284
298
  continue
285
-
286
299
  # If none matched, this is an unsupported edge record
287
300
  raise GraphCompileError(
288
301
  "Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
289
302
  "(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
290
303
  )
291
-
292
304
  # reachability from start
293
305
  seen: set[str] = set()
294
306
  stack = [start_name]
@@ -298,11 +310,9 @@ class Graph:
298
310
  continue
299
311
  seen.add(cur)
300
312
  stack.extend(adj.get(cur, ()))
301
-
302
313
  unreachable = sorted(set(nodes.keys()) - seen)
303
314
  if unreachable:
304
315
  raise GraphCompileError(f"Unreachable nodes: {unreachable}")
305
-
306
316
  # optional cycle detection (strict mode)
307
317
  if strict:
308
318
  preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
@@ -313,7 +323,6 @@ class Graph:
313
323
  list(TopologicalSorter(preds).static_order())
314
324
  except CycleError as e:
315
325
  raise GraphCompileError("cycle detected in graph (strict mode)") from e
316
-
317
326
  return self
318
327
 
319
328
  def _validate_output(self, output: dict[str, Any], node_name: str):
@@ -323,57 +332,123 @@ class Graph:
323
332
  f"Instead, got {type(output)} from '{output}'."
324
333
  )
325
334
 
326
- async def execute(self, input, callback: Callback | None = None):
327
- # TODO JB: may need to add init callback here to init the queue on every new execution
328
- if callback is None:
329
- callback = self.get_callback()
330
-
331
- # Type assertion to tell the type checker that start_node is not None after compile()
332
- assert self.start_node is not None, "Graph must be compiled before execution"
333
- current_node = self.start_node
335
+ def _get_next_nodes(self, current_node: NodeProtocol) -> list[NodeProtocol]:
336
+ """Return all successor nodes for the given node."""
337
+ # we skip JoinEdge because they don't have regular destinations
338
+ # and next nodes for those are handled in the execute method
339
+ return [
340
+ edge.destination
341
+ for edge in self.edges
342
+ if isinstance(edge, Edge) and edge.source == current_node
343
+ ]
344
+
345
+ async def _invoke_node(
346
+ self, node: NodeProtocol, state: dict[str, Any], callback: Callback
347
+ ):
348
+ if node.stream:
349
+ await callback.start_node(node_name=node.name)
350
+ output = await node.invoke(input=state, callback=callback, state=self.state)
351
+ self._validate_output(output=output, node_name=node.name)
352
+ await callback.end_node(node_name=node.name)
353
+ else:
354
+ output = await node.invoke(input=state, state=self.state)
355
+ self._validate_output(output=output, node_name=node.name)
356
+ return output
334
357
 
335
- state = input
336
- # Don't reset the graph state if it was initialized with initial_state
337
- steps = 0
358
+ async def _execute_branch(
359
+ self,
360
+ current_node: NodeProtocol,
361
+ state: dict[str, Any],
362
+ callback: Callback,
363
+ steps: int,
364
+ stop_at_join: bool = False,
365
+ ):
366
+ """Recursively execute a branch starting from `current_node`.
367
+ When a node has multiple successors, run them concurrently and merge their outputs."""
338
368
  while True:
339
- # we invoke the node here
340
- if current_node.stream:
341
- # add callback tokens and param here if we are streaming
342
- await callback.start_node(node_name=current_node.name)
343
- # Include graph's internal state in the node execution context
344
- output = await current_node.invoke(
345
- input=state, callback=callback, state=self.state
346
- )
347
- self._validate_output(output=output, node_name=current_node.name)
348
- await callback.end_node(node_name=current_node.name)
349
- else:
350
- # Include graph's internal state in the node execution context
351
- output = await current_node.invoke(input=state, state=self.state)
352
- self._validate_output(output=output, node_name=current_node.name)
353
- # add output to state
354
- state = {**state, **output}
369
+ output = await self._invoke_node(current_node, state, callback)
370
+ state = {**state, **output} # merge node output into local state
355
371
  if current_node.is_end:
356
- # finish loop if this was an end node
357
372
  break
358
373
  if current_node.is_router:
359
- # if we have a router node we let the router decide the next node
360
374
  next_node_name = str(output["choice"])
361
375
  del output["choice"]
362
376
  current_node = self._get_node_by_name(node_name=next_node_name)
377
+ continue
378
+ if stop_at_join and current_node in self.join_nodes:
379
+ # for parallel branches, wait at JoinEdge until all branches are complete
380
+ return state
381
+
382
+ next_nodes = self._get_next_nodes(current_node)
383
+ if not next_nodes:
384
+ raise Exception(
385
+ f"No outgoing edge found for current node '{current_node.name}'."
386
+ )
387
+ if len(next_nodes) == 1:
388
+ current_node = next_nodes[0]
363
389
  else:
364
- # otherwise, we have linear path
365
- current_node = self._get_next_node(current_node=current_node)
390
+ # Run each branch concurrently
391
+ results = await asyncio.gather(
392
+ *[
393
+ self._execute_branch(
394
+ current_node=n,
395
+ state=state.copy(),
396
+ callback=callback,
397
+ steps=steps + 1,
398
+ stop_at_join=True, # force parallel branches to wait at JoinEdge
399
+ )
400
+ for n in next_nodes
401
+ ]
402
+ )
403
+ # merge states returned by each branch
404
+ merged = state.copy()
405
+ for res in results:
406
+ for k, v in res.items():
407
+ if k != "callback":
408
+ merged[k] = v
409
+ if set(next_nodes) & self.join_nodes:
410
+ # if any of the next nodes are join nodes, we need to continue from the
411
+ # JoinEdge.destination node
412
+ join_edge = next(
413
+ (
414
+ e for e in self.edges if isinstance(e, JoinEdge)
415
+ and any(n in e.sources for n in next_nodes)
416
+ ),
417
+ None
418
+ )
419
+ if not join_edge:
420
+ raise Exception("No JoinEdge found for next_nodes")
421
+ # set current_node (for next iteration) to the JoinEdge.destination
422
+ current_node = join_edge.destination
423
+ # continue to the destination node with our merged state
424
+ state = merged
425
+ continue
426
+ else:
427
+ # if this happens we have multiple branches that do not join so we
428
+ # can just return the merged states
429
+ return merged
366
430
  steps += 1
367
431
  if steps >= self.max_steps:
368
432
  raise Exception(
369
- f"Max steps reached: {self.max_steps}. You can modify this "
370
- "by setting `max_steps` when initializing the Graph object."
433
+ f"Max steps reached: {self.max_steps}. You can modify this by setting `max_steps` when initializing the Graph object."
371
434
  )
435
+ return state
436
+
437
+ async def execute(self, input: dict[str, Any], callback: Callback | None = None):
438
+ # TODO JB: may need to add init callback here to init the queue on every new execution
439
+ if callback is None:
440
+ callback = self.get_callback()
441
+
442
+ # Type assertion to tell the type checker that start_node is not None after compile()
443
+ assert self.start_node is not None, "Graph must be compiled before execution"
444
+
445
+ state = input
446
+ result = await self._execute_branch(self.start_node, state, callback, 0)
372
447
  # TODO JB: may need to add end callback here to close the queue for every execution
373
- if callback and "callback" in state:
448
+ if callback and "callback" in result:
374
449
  await callback.close()
375
- del state["callback"]
376
- return state
450
+ del result["callback"]
451
+ return result
377
452
 
378
453
  async def execute_many(
379
454
  self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
@@ -439,12 +514,46 @@ class Graph:
439
514
 
440
515
  def _get_next_node(self, current_node):
441
516
  for edge in self.edges:
442
- if edge.source == current_node:
517
+ if isinstance(edge, Edge) and edge.source == current_node:
443
518
  return edge.destination
519
+ # we skip JoinEdge because they don't have regular destinations
520
+ # and next nodes for those are handled in the execute method
444
521
  raise Exception(
445
522
  f"No outgoing edge found for current node '{current_node.name}'."
446
523
  )
447
524
 
525
+ def add_parallel(
526
+ self, source: NodeProtocol | str, destinations: list[NodeProtocol | str]
527
+ ):
528
+ """Add multiple outgoing edges from a single source node to be executed in parallel.
529
+
530
+ Args:
531
+ source: The source node for the parallel branches.
532
+ destinations: The list of destination nodes for the parallel branches.
533
+ """
534
+ for dest in destinations:
535
+ self.add_edge(source, dest)
536
+ return self
537
+
538
+ def add_join(
539
+ self, sources: list[NodeProtocol | str], destination: NodeProtocol | str
540
+ ):
541
+ """Joins multiple parallel branches into a single branch.
542
+
543
+ Args:
544
+ sources: The list of source nodes for the join.
545
+ destination: The destination node for the join.
546
+ """
547
+ # get source nodes from graph
548
+ source_nodes = [self._get_node(node_candidate=source) for source in sources]
549
+ # get destination node from graph
550
+ destination_node = self._get_node(node_candidate=destination)
551
+ # create join edge
552
+ edge = JoinEdge(source_nodes, destination_node)
553
+ self.edges.append(edge)
554
+ self.join_nodes.update(source_nodes)
555
+ return self
556
+
448
557
  def visualize(self, *, save_path: str | None = None):
449
558
  """Render the current graph. If matplotlib is not installed,
450
559
  raise a helpful error telling users to install the viz extra.
@@ -540,3 +649,8 @@ class Edge:
540
649
  def __init__(self, source, destination):
541
650
  self.source = source
542
651
  self.destination = destination
652
+
653
+ class JoinEdge:
654
+ def __init__(self, sources, destination):
655
+ self.sources = sources
656
+ self.destination = destination
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.10rc1
3
+ Version: 0.0.10rc2
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -1,11 +1,11 @@
1
1
  graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
2
2
  graphai/callback.py,sha256=xQhK1xbjkouemaIB5hjPZpo3thxiA6VVu9TpHslOsww,13851
3
- graphai/graph.py,sha256=CSdRYlQnfVHNw4XyKMEZ1rJqqg6rOx-aT8w9xumFhBs,20138
3
+ graphai/graph.py,sha256=ZqwQWO3n8rcrdWOlga6YyRZEHSc_82U4Q0pWqmSn7ko,24867
4
4
  graphai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  graphai/utils.py,sha256=LIFg-fQalU9sB5DCuk6is48OdpEgNX95i9h-YddFbvM,11717
6
6
  graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
7
7
  graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
8
- graphai_lib-0.0.10rc1.dist-info/METADATA,sha256=2sIQkGfW7m0UoOEioVono53cVqpDkxRlxoDVlt2Tl1o,914
9
- graphai_lib-0.0.10rc1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
- graphai_lib-0.0.10rc1.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
11
- graphai_lib-0.0.10rc1.dist-info/RECORD,,
8
+ graphai_lib-0.0.10rc2.dist-info/METADATA,sha256=NYLHcB-zLC9abdWHqCjGDsyJu-YXgBKNRssu8Bu9SFY,914
9
+ graphai_lib-0.0.10rc2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ graphai_lib-0.0.10rc2.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
11
+ graphai_lib-0.0.10rc2.dist-info/RECORD,,