graphai-lib 0.0.9rc3__py3-none-any.whl → 0.0.10rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphai/callback.py +17 -0
- graphai/graph.py +238 -85
- {graphai_lib-0.0.9rc3.dist-info → graphai_lib-0.0.10rc2.dist-info}/METADATA +1 -1
- graphai_lib-0.0.10rc2.dist-info/RECORD +11 -0
- graphai_lib-0.0.9rc3.dist-info/RECORD +0 -11
- {graphai_lib-0.0.9rc3.dist-info → graphai_lib-0.0.10rc2.dist-info}/WHEEL +0 -0
- {graphai_lib-0.0.9rc3.dist-info → graphai_lib-0.0.10rc2.dist-info}/top_level.txt +0 -0
graphai/callback.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from enum import Enum
|
4
|
+
import json
|
4
5
|
from pydantic import Field
|
5
6
|
from typing import Any
|
6
7
|
from collections.abc import AsyncIterator
|
@@ -43,6 +44,22 @@ class GraphEvent:
|
|
43
44
|
token: str | None = None
|
44
45
|
params: dict[str, Any] | None = None
|
45
46
|
|
47
|
+
def encode(self, charset: str = "utf-8") -> bytes:
|
48
|
+
"""Encodes the event as a JSON string, important for compatability with FastAPI
|
49
|
+
and starlette.
|
50
|
+
|
51
|
+
:param charset: The character set to use for encoding the event.
|
52
|
+
:type charset: str
|
53
|
+
"""
|
54
|
+
event_dict = {
|
55
|
+
"type": self.type.value if hasattr(self.type, "value") else str(self.type),
|
56
|
+
"identifier": self.identifier,
|
57
|
+
"token": self.token,
|
58
|
+
"params": self.params,
|
59
|
+
}
|
60
|
+
data = f"data: {json.dumps(event_dict, ensure_ascii=False, separators=(',', ':'))}\n\n"
|
61
|
+
return data.encode(charset)
|
62
|
+
|
46
63
|
|
47
64
|
class Callback:
|
48
65
|
"""The original callback handler class. Outputs a stream of structured text
|
graphai/graph.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
|
2
|
+
import asyncio
|
3
|
+
from typing import Any, Iterable, Protocol
|
4
|
+
from graphlib import TopologicalSorter, CycleError
|
5
|
+
|
3
6
|
from graphai.callback import Callback
|
4
7
|
from graphai.utils import logger
|
5
8
|
|
@@ -62,22 +65,41 @@ class Graph:
|
|
62
65
|
self.edges: list[Any] = []
|
63
66
|
self.start_node: NodeProtocol | None = None
|
64
67
|
self.end_nodes: list[NodeProtocol] = []
|
68
|
+
self.join_nodes: set[NodeProtocol] = set()
|
65
69
|
self.Callback: type[Callback] = Callback
|
66
70
|
self.max_steps = max_steps
|
67
71
|
self.state = initial_state or {}
|
68
72
|
|
69
73
|
# Allow getting and setting the graph's internal state
|
70
74
|
def get_state(self) -> dict[str, Any]:
|
71
|
-
"""Get the current graph state.
|
75
|
+
"""Get the current graph state.
|
76
|
+
|
77
|
+
Returns:
|
78
|
+
The current graph state.
|
79
|
+
"""
|
72
80
|
return self.state
|
73
81
|
|
74
82
|
def set_state(self, state: dict[str, Any]) -> Graph:
|
75
|
-
"""Set the graph state.
|
83
|
+
"""Set the graph state.
|
84
|
+
|
85
|
+
Args:
|
86
|
+
state: The new state to set for the graph.
|
87
|
+
|
88
|
+
Returns:
|
89
|
+
The graph instance.
|
90
|
+
"""
|
76
91
|
self.state = state
|
77
92
|
return self
|
78
93
|
|
79
94
|
def update_state(self, values: dict[str, Any]) -> Graph:
|
80
|
-
"""Update the graph state with new values.
|
95
|
+
"""Update the graph state with new values.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
values: The new values to update the graph state with.
|
99
|
+
|
100
|
+
Returns:
|
101
|
+
The graph instance.
|
102
|
+
"""
|
81
103
|
self.state.update(values)
|
82
104
|
return self
|
83
105
|
|
@@ -87,6 +109,14 @@ class Graph:
|
|
87
109
|
return self
|
88
110
|
|
89
111
|
def add_node(self, node: NodeProtocol) -> Graph:
|
112
|
+
"""Adds a node to the graph.
|
113
|
+
|
114
|
+
Args:
|
115
|
+
node: The node to add to the graph.
|
116
|
+
|
117
|
+
Raises:
|
118
|
+
Exception: If a node with the same name already exists in the graph.
|
119
|
+
"""
|
90
120
|
if node.name in self.nodes:
|
91
121
|
raise Exception(f"Node with name '{node.name}' already exists.")
|
92
122
|
self.nodes[node.name] = node
|
@@ -102,6 +132,18 @@ class Graph:
|
|
102
132
|
self.end_nodes.append(node)
|
103
133
|
return self
|
104
134
|
|
135
|
+
def _get_node(self, node_candidate: NodeProtocol | str) -> NodeProtocol:
|
136
|
+
# first get node from graph
|
137
|
+
if isinstance(node_candidate, str):
|
138
|
+
node = self.nodes.get(node_candidate)
|
139
|
+
else:
|
140
|
+
# check if it's a node-like object by looking for required attributes
|
141
|
+
if hasattr(node_candidate, "name"):
|
142
|
+
node = self.nodes.get(node_candidate.name)
|
143
|
+
if node is None:
|
144
|
+
raise ValueError(f"Node with name '{node_candidate}' not found.")
|
145
|
+
return node
|
146
|
+
|
105
147
|
def add_edge(
|
106
148
|
self, source: NodeProtocol | str, destination: NodeProtocol | str
|
107
149
|
) -> Graph:
|
@@ -113,33 +155,10 @@ class Graph:
|
|
113
155
|
"""
|
114
156
|
source_node, destination_node = None, None
|
115
157
|
# get source node from graph
|
116
|
-
|
117
|
-
if isinstance(source, str):
|
118
|
-
source_node = self.nodes.get(source)
|
119
|
-
source_name = source
|
120
|
-
else:
|
121
|
-
# Check if it's a node-like object by looking for required attributes
|
122
|
-
if hasattr(source, "name"):
|
123
|
-
source_node = self.nodes.get(source.name)
|
124
|
-
source_name = source.name
|
125
|
-
else:
|
126
|
-
source_name = str(source)
|
127
|
-
if source_node is None:
|
128
|
-
raise ValueError(f"Node with name '{source_name}' not found.")
|
158
|
+
source_node = self._get_node(node_candidate=source)
|
129
159
|
# get destination node from graph
|
130
|
-
|
131
|
-
|
132
|
-
destination_node = self.nodes.get(destination)
|
133
|
-
destination_name = destination
|
134
|
-
else:
|
135
|
-
# Check if it's a node-like object by looking for required attributes
|
136
|
-
if hasattr(destination, "name"):
|
137
|
-
destination_node = self.nodes.get(destination.name)
|
138
|
-
destination_name = destination.name
|
139
|
-
else:
|
140
|
-
destination_name = str(destination)
|
141
|
-
if destination_node is None:
|
142
|
-
raise ValueError(f"Node with name '{destination_name}' not found.")
|
160
|
+
destination_node = self._get_node(node_candidate=destination)
|
161
|
+
# create edge
|
143
162
|
edge = Edge(source_node, destination_node)
|
144
163
|
self.edges.append(edge)
|
145
164
|
return self
|
@@ -150,6 +169,14 @@ class Graph:
|
|
150
169
|
router: NodeProtocol,
|
151
170
|
destinations: list[NodeProtocol],
|
152
171
|
) -> Graph:
|
172
|
+
"""Adds a router node, allowing for a decision to be made on which branch to
|
173
|
+
follow based on the `choice` output of the router node.
|
174
|
+
|
175
|
+
Args:
|
176
|
+
sources: The list of source nodes for the router.
|
177
|
+
router: The router node.
|
178
|
+
destinations: The list of destination nodes for the router.
|
179
|
+
"""
|
153
180
|
if not router.is_router:
|
154
181
|
raise TypeError("A router object must be passed to the router parameter.")
|
155
182
|
[self.add_edge(source, router) for source in sources]
|
@@ -165,21 +192,19 @@ class Graph:
|
|
165
192
|
self.end_node = node
|
166
193
|
return self
|
167
194
|
|
168
|
-
def compile(self) ->
|
169
|
-
"""
|
170
|
-
Validate the graph:
|
195
|
+
def compile(self, *, strict: bool = False) -> Graph:
|
196
|
+
"""Validate the graph:
|
171
197
|
- exactly one start node present (or Graph.start_node set)
|
172
198
|
- at least one end node present
|
173
199
|
- all edges reference known nodes
|
174
|
-
- no cycles
|
175
200
|
- all nodes reachable from the start
|
201
|
+
(optional) **no cycles** when strict=True
|
176
202
|
Returns self on success; raises GraphCompileError otherwise.
|
177
203
|
"""
|
178
204
|
# nodes map
|
179
205
|
nodes = getattr(self, "nodes", None)
|
180
206
|
if not isinstance(nodes, dict) or not nodes:
|
181
207
|
raise GraphCompileError("No nodes have been added to the graph")
|
182
|
-
|
183
208
|
start_name: str | None = None
|
184
209
|
# Bind and narrow the attribute for mypy
|
185
210
|
start_node: _HasName | None = getattr(self, "start_node", None)
|
@@ -195,21 +220,17 @@ class Graph:
|
|
195
220
|
raise GraphCompileError(f"Multiple start nodes defined: {starts}")
|
196
221
|
if len(starts) == 1:
|
197
222
|
start_name = starts[0]
|
198
|
-
|
199
223
|
if not start_name:
|
200
224
|
raise GraphCompileError("No start node defined")
|
201
|
-
|
202
225
|
# at least one end node
|
203
226
|
if not any(
|
204
227
|
getattr(n, "is_end", False) or getattr(n, "end", False)
|
205
228
|
for n in nodes.values()
|
206
229
|
):
|
207
230
|
raise GraphCompileError("No end node defined")
|
208
|
-
|
209
231
|
# normalize edges into adjacency {src: set(dst)}
|
210
232
|
raw_edges = getattr(self, "edges", None)
|
211
233
|
adj: dict[str, set[str]] = {name: set() for name in nodes.keys()}
|
212
|
-
|
213
234
|
def _add_edge(src: str, dst: str) -> None:
|
214
235
|
if src not in nodes:
|
215
236
|
raise GraphCompileError(f"Edge references unknown source node: {src}")
|
@@ -218,7 +239,6 @@ class Graph:
|
|
218
239
|
f"Edge from {src} references unknown node(s): ['{dst}']"
|
219
240
|
)
|
220
241
|
adj[src].add(dst)
|
221
|
-
|
222
242
|
if raw_edges is None:
|
223
243
|
pass
|
224
244
|
elif isinstance(raw_edges, dict):
|
@@ -238,13 +258,11 @@ class Graph:
|
|
238
258
|
iterator = iter(raw_edges)
|
239
259
|
except TypeError:
|
240
260
|
raise GraphCompileError("Internal edge map has unsupported type")
|
241
|
-
|
242
261
|
for item in iterator:
|
243
262
|
# (src, dst) OR (src, Iterable[dst])
|
244
263
|
if isinstance(item, (tuple, list)) and len(item) == 2:
|
245
264
|
raw_src, rhs = item
|
246
265
|
src = _require_name(raw_src, "source")
|
247
|
-
|
248
266
|
if isinstance(rhs, str) or getattr(rhs, "name", None):
|
249
267
|
dst = _require_name(rhs, "destination")
|
250
268
|
_add_edge(src, rhs)
|
@@ -259,7 +277,6 @@ class Graph:
|
|
259
277
|
"Edge tuple second item must be a destination or an iterable of destinations"
|
260
278
|
)
|
261
279
|
continue
|
262
|
-
|
263
280
|
# Mapping-style: {"source": "...", "destination": "..."} or {"src": "...", "dst": "..."}
|
264
281
|
if isinstance(item, dict):
|
265
282
|
src = _require_name(item.get("source", item.get("src")), "source")
|
@@ -268,7 +285,6 @@ class Graph:
|
|
268
285
|
)
|
269
286
|
_add_edge(src, dst)
|
270
287
|
continue
|
271
|
-
|
272
288
|
# Object with attributes .source/.destination (or .src/.dst)
|
273
289
|
if hasattr(item, "source") or hasattr(item, "src"):
|
274
290
|
src = _require_name(
|
@@ -280,13 +296,11 @@ class Graph:
|
|
280
296
|
)
|
281
297
|
_add_edge(src, dst)
|
282
298
|
continue
|
283
|
-
|
284
299
|
# If none matched, this is an unsupported edge record
|
285
300
|
raise GraphCompileError(
|
286
301
|
"Edges must be dict[str, Iterable[str]] or an iterable of (src, dst), "
|
287
302
|
"(src, Iterable[dst]), mapping{'source'/'destination'}, or objects with .source/.destination"
|
288
303
|
)
|
289
|
-
|
290
304
|
# reachability from start
|
291
305
|
seen: set[str] = set()
|
292
306
|
stack = [start_name]
|
@@ -296,11 +310,19 @@ class Graph:
|
|
296
310
|
continue
|
297
311
|
seen.add(cur)
|
298
312
|
stack.extend(adj.get(cur, ()))
|
299
|
-
|
300
313
|
unreachable = sorted(set(nodes.keys()) - seen)
|
301
314
|
if unreachable:
|
302
315
|
raise GraphCompileError(f"Unreachable nodes: {unreachable}")
|
303
|
-
|
316
|
+
# optional cycle detection (strict mode)
|
317
|
+
if strict:
|
318
|
+
preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
|
319
|
+
for s, ds in adj.items():
|
320
|
+
for d in ds:
|
321
|
+
preds[d].add(s)
|
322
|
+
try:
|
323
|
+
list(TopologicalSorter(preds).static_order())
|
324
|
+
except CycleError as e:
|
325
|
+
raise GraphCompileError("cycle detected in graph (strict mode)") from e
|
304
326
|
return self
|
305
327
|
|
306
328
|
def _validate_output(self, output: dict[str, Any], node_name: str):
|
@@ -310,57 +332,149 @@ class Graph:
|
|
310
332
|
f"Instead, got {type(output)} from '{output}'."
|
311
333
|
)
|
312
334
|
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
335
|
+
def _get_next_nodes(self, current_node: NodeProtocol) -> list[NodeProtocol]:
|
336
|
+
"""Return all successor nodes for the given node."""
|
337
|
+
# we skip JoinEdge because they don't have regular destinations
|
338
|
+
# and next nodes for those are handled in the execute method
|
339
|
+
return [
|
340
|
+
edge.destination
|
341
|
+
for edge in self.edges
|
342
|
+
if isinstance(edge, Edge) and edge.source == current_node
|
343
|
+
]
|
344
|
+
|
345
|
+
async def _invoke_node(
|
346
|
+
self, node: NodeProtocol, state: dict[str, Any], callback: Callback
|
347
|
+
):
|
348
|
+
if node.stream:
|
349
|
+
await callback.start_node(node_name=node.name)
|
350
|
+
output = await node.invoke(input=state, callback=callback, state=self.state)
|
351
|
+
self._validate_output(output=output, node_name=node.name)
|
352
|
+
await callback.end_node(node_name=node.name)
|
353
|
+
else:
|
354
|
+
output = await node.invoke(input=state, state=self.state)
|
355
|
+
self._validate_output(output=output, node_name=node.name)
|
356
|
+
return output
|
321
357
|
|
322
|
-
|
323
|
-
|
324
|
-
|
358
|
+
async def _execute_branch(
|
359
|
+
self,
|
360
|
+
current_node: NodeProtocol,
|
361
|
+
state: dict[str, Any],
|
362
|
+
callback: Callback,
|
363
|
+
steps: int,
|
364
|
+
stop_at_join: bool = False,
|
365
|
+
):
|
366
|
+
"""Recursively execute a branch starting from `current_node`.
|
367
|
+
When a node has multiple successors, run them concurrently and merge their outputs."""
|
325
368
|
while True:
|
326
|
-
|
327
|
-
|
328
|
-
# add callback tokens and param here if we are streaming
|
329
|
-
await callback.start_node(node_name=current_node.name)
|
330
|
-
# Include graph's internal state in the node execution context
|
331
|
-
output = await current_node.invoke(
|
332
|
-
input=state, callback=callback, state=self.state
|
333
|
-
)
|
334
|
-
self._validate_output(output=output, node_name=current_node.name)
|
335
|
-
await callback.end_node(node_name=current_node.name)
|
336
|
-
else:
|
337
|
-
# Include graph's internal state in the node execution context
|
338
|
-
output = await current_node.invoke(input=state, state=self.state)
|
339
|
-
self._validate_output(output=output, node_name=current_node.name)
|
340
|
-
# add output to state
|
341
|
-
state = {**state, **output}
|
369
|
+
output = await self._invoke_node(current_node, state, callback)
|
370
|
+
state = {**state, **output} # merge node output into local state
|
342
371
|
if current_node.is_end:
|
343
|
-
# finish loop if this was an end node
|
344
372
|
break
|
345
373
|
if current_node.is_router:
|
346
|
-
# if we have a router node we let the router decide the next node
|
347
374
|
next_node_name = str(output["choice"])
|
348
375
|
del output["choice"]
|
349
376
|
current_node = self._get_node_by_name(node_name=next_node_name)
|
377
|
+
continue
|
378
|
+
if stop_at_join and current_node in self.join_nodes:
|
379
|
+
# for parallel branches, wait at JoinEdge until all branches are complete
|
380
|
+
return state
|
381
|
+
|
382
|
+
next_nodes = self._get_next_nodes(current_node)
|
383
|
+
if not next_nodes:
|
384
|
+
raise Exception(
|
385
|
+
f"No outgoing edge found for current node '{current_node.name}'."
|
386
|
+
)
|
387
|
+
if len(next_nodes) == 1:
|
388
|
+
current_node = next_nodes[0]
|
350
389
|
else:
|
351
|
-
#
|
352
|
-
|
390
|
+
# Run each branch concurrently
|
391
|
+
results = await asyncio.gather(
|
392
|
+
*[
|
393
|
+
self._execute_branch(
|
394
|
+
current_node=n,
|
395
|
+
state=state.copy(),
|
396
|
+
callback=callback,
|
397
|
+
steps=steps + 1,
|
398
|
+
stop_at_join=True, # force parallel branches to wait at JoinEdge
|
399
|
+
)
|
400
|
+
for n in next_nodes
|
401
|
+
]
|
402
|
+
)
|
403
|
+
# merge states returned by each branch
|
404
|
+
merged = state.copy()
|
405
|
+
for res in results:
|
406
|
+
for k, v in res.items():
|
407
|
+
if k != "callback":
|
408
|
+
merged[k] = v
|
409
|
+
if set(next_nodes) & self.join_nodes:
|
410
|
+
# if any of the next nodes are join nodes, we need to continue from the
|
411
|
+
# JoinEdge.destination node
|
412
|
+
join_edge = next(
|
413
|
+
(
|
414
|
+
e for e in self.edges if isinstance(e, JoinEdge)
|
415
|
+
and any(n in e.sources for n in next_nodes)
|
416
|
+
),
|
417
|
+
None
|
418
|
+
)
|
419
|
+
if not join_edge:
|
420
|
+
raise Exception("No JoinEdge found for next_nodes")
|
421
|
+
# set current_node (for next iteration) to the JoinEdge.destination
|
422
|
+
current_node = join_edge.destination
|
423
|
+
# continue to the destination node with our merged state
|
424
|
+
state = merged
|
425
|
+
continue
|
426
|
+
else:
|
427
|
+
# if this happens we have multiple branches that do not join so we
|
428
|
+
# can just return the merged states
|
429
|
+
return merged
|
353
430
|
steps += 1
|
354
431
|
if steps >= self.max_steps:
|
355
432
|
raise Exception(
|
356
|
-
f"Max steps reached: {self.max_steps}. You can modify this "
|
357
|
-
"by setting `max_steps` when initializing the Graph object."
|
433
|
+
f"Max steps reached: {self.max_steps}. You can modify this by setting `max_steps` when initializing the Graph object."
|
358
434
|
)
|
435
|
+
return state
|
436
|
+
|
437
|
+
async def execute(self, input: dict[str, Any], callback: Callback | None = None):
|
438
|
+
# TODO JB: may need to add init callback here to init the queue on every new execution
|
439
|
+
if callback is None:
|
440
|
+
callback = self.get_callback()
|
441
|
+
|
442
|
+
# Type assertion to tell the type checker that start_node is not None after compile()
|
443
|
+
assert self.start_node is not None, "Graph must be compiled before execution"
|
444
|
+
|
445
|
+
state = input
|
446
|
+
result = await self._execute_branch(self.start_node, state, callback, 0)
|
359
447
|
# TODO JB: may need to add end callback here to close the queue for every execution
|
360
|
-
if callback and "callback" in
|
448
|
+
if callback and "callback" in result:
|
361
449
|
await callback.close()
|
362
|
-
del
|
363
|
-
return
|
450
|
+
del result["callback"]
|
451
|
+
return result
|
452
|
+
|
453
|
+
async def execute_many(
|
454
|
+
self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
|
455
|
+
) -> list[Any]:
|
456
|
+
"""Execute the graph on many inputs concurrently.
|
457
|
+
|
458
|
+
:param inputs: An iterable of input dicts to feed into the graph.
|
459
|
+
:type inputs: Iterable[dict]
|
460
|
+
:param concurrency: Maximum number of graph executions to run at once.
|
461
|
+
:type concurrency: int
|
462
|
+
:param state: Optional shared state to pass to each execution.
|
463
|
+
If you want isolated state per execution, pass None
|
464
|
+
and the graph's normal semantics will apply.
|
465
|
+
:type state: Optional[Any]
|
466
|
+
:returns: The list of results in the same order as ``inputs``.
|
467
|
+
:rtype: list[Any]
|
468
|
+
"""
|
469
|
+
|
470
|
+
sem = asyncio.Semaphore(concurrency)
|
471
|
+
|
472
|
+
async def _run_one(inp: dict[str, Any]) -> Any:
|
473
|
+
async with sem:
|
474
|
+
return await self.execute(input=inp)
|
475
|
+
|
476
|
+
tasks = [asyncio.create_task(_run_one(i)) for i in inputs]
|
477
|
+
return await asyncio.gather(*tasks)
|
364
478
|
|
365
479
|
def get_callback(self):
|
366
480
|
"""Get a new instance of the callback class.
|
@@ -400,12 +514,46 @@ class Graph:
|
|
400
514
|
|
401
515
|
def _get_next_node(self, current_node):
|
402
516
|
for edge in self.edges:
|
403
|
-
if edge.source == current_node:
|
517
|
+
if isinstance(edge, Edge) and edge.source == current_node:
|
404
518
|
return edge.destination
|
519
|
+
# we skip JoinEdge because they don't have regular destinations
|
520
|
+
# and next nodes for those are handled in the execute method
|
405
521
|
raise Exception(
|
406
522
|
f"No outgoing edge found for current node '{current_node.name}'."
|
407
523
|
)
|
408
524
|
|
525
|
+
def add_parallel(
|
526
|
+
self, source: NodeProtocol | str, destinations: list[NodeProtocol | str]
|
527
|
+
):
|
528
|
+
"""Add multiple outgoing edges from a single source node to be executed in parallel.
|
529
|
+
|
530
|
+
Args:
|
531
|
+
source: The source node for the parallel branches.
|
532
|
+
destinations: The list of destination nodes for the parallel branches.
|
533
|
+
"""
|
534
|
+
for dest in destinations:
|
535
|
+
self.add_edge(source, dest)
|
536
|
+
return self
|
537
|
+
|
538
|
+
def add_join(
|
539
|
+
self, sources: list[NodeProtocol | str], destination: NodeProtocol | str
|
540
|
+
):
|
541
|
+
"""Joins multiple parallel branches into a single branch.
|
542
|
+
|
543
|
+
Args:
|
544
|
+
sources: The list of source nodes for the join.
|
545
|
+
destination: The destination node for the join.
|
546
|
+
"""
|
547
|
+
# get source nodes from graph
|
548
|
+
source_nodes = [self._get_node(node_candidate=source) for source in sources]
|
549
|
+
# get destination node from graph
|
550
|
+
destination_node = self._get_node(node_candidate=destination)
|
551
|
+
# create join edge
|
552
|
+
edge = JoinEdge(source_nodes, destination_node)
|
553
|
+
self.edges.append(edge)
|
554
|
+
self.join_nodes.update(source_nodes)
|
555
|
+
return self
|
556
|
+
|
409
557
|
def visualize(self, *, save_path: str | None = None):
|
410
558
|
"""Render the current graph. If matplotlib is not installed,
|
411
559
|
raise a helpful error telling users to install the viz extra.
|
@@ -501,3 +649,8 @@ class Edge:
|
|
501
649
|
def __init__(self, source, destination):
|
502
650
|
self.source = source
|
503
651
|
self.destination = destination
|
652
|
+
|
653
|
+
class JoinEdge:
|
654
|
+
def __init__(self, sources, destination):
|
655
|
+
self.sources = sources
|
656
|
+
self.destination = destination
|
@@ -0,0 +1,11 @@
|
|
1
|
+
graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
|
2
|
+
graphai/callback.py,sha256=xQhK1xbjkouemaIB5hjPZpo3thxiA6VVu9TpHslOsww,13851
|
3
|
+
graphai/graph.py,sha256=ZqwQWO3n8rcrdWOlga6YyRZEHSc_82U4Q0pWqmSn7ko,24867
|
4
|
+
graphai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
+
graphai/utils.py,sha256=LIFg-fQalU9sB5DCuk6is48OdpEgNX95i9h-YddFbvM,11717
|
6
|
+
graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
|
7
|
+
graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
|
8
|
+
graphai_lib-0.0.10rc2.dist-info/METADATA,sha256=NYLHcB-zLC9abdWHqCjGDsyJu-YXgBKNRssu8Bu9SFY,914
|
9
|
+
graphai_lib-0.0.10rc2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
10
|
+
graphai_lib-0.0.10rc2.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
|
11
|
+
graphai_lib-0.0.10rc2.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
|
2
|
-
graphai/callback.py,sha256=NrwArRBHWXvDodIdkDKmCAX3RlF1Zr7dGJsz1anNPgM,13195
|
3
|
-
graphai/graph.py,sha256=O0hZ_29ln8oTYnk5EX5bdhO6-Cr9pm-6IVCq-ZPWbgY,18526
|
4
|
-
graphai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
graphai/utils.py,sha256=LIFg-fQalU9sB5DCuk6is48OdpEgNX95i9h-YddFbvM,11717
|
6
|
-
graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
|
7
|
-
graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
|
8
|
-
graphai_lib-0.0.9rc3.dist-info/METADATA,sha256=9iXKyNdp-EsY0_D5hVTy2eixMmkqnVsFR1G3cD-NvVs,913
|
9
|
-
graphai_lib-0.0.9rc3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
10
|
-
graphai_lib-0.0.9rc3.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
|
11
|
-
graphai_lib-0.0.9rc3.dist-info/RECORD,,
|
File without changes
|
File without changes
|