graphai-lib 0.0.9rc3__tar.gz → 0.0.10rc1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/PKG-INFO +1 -1
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/callback.py +17 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/graph.py +42 -3
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/PKG-INFO +1 -1
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/pyproject.toml +1 -1
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/README.md +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/__init__.py +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/nodes/__init__.py +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/nodes/base.py +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/py.typed +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai/utils.py +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/SOURCES.txt +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/dependency_links.txt +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/requires.txt +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/top_level.txt +0 -0
- {graphai_lib-0.0.9rc3 → graphai_lib-0.0.10rc1}/setup.cfg +0 -0
@@ -1,6 +1,7 @@
|
|
1
1
|
import asyncio
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from enum import Enum
|
4
|
+
import json
|
4
5
|
from pydantic import Field
|
5
6
|
from typing import Any
|
6
7
|
from collections.abc import AsyncIterator
|
@@ -43,6 +44,22 @@ class GraphEvent:
|
|
43
44
|
token: str | None = None
|
44
45
|
params: dict[str, Any] | None = None
|
45
46
|
|
47
|
+
def encode(self, charset: str = "utf-8") -> bytes:
|
48
|
+
"""Encodes the event as a JSON string, important for compatibility with FastAPI
|
49
|
+
and starlette.
|
50
|
+
|
51
|
+
:param charset: The character set to use for encoding the event.
|
52
|
+
:type charset: str
|
53
|
+
"""
|
54
|
+
event_dict = {
|
55
|
+
"type": self.type.value if hasattr(self.type, "value") else str(self.type),
|
56
|
+
"identifier": self.identifier,
|
57
|
+
"token": self.token,
|
58
|
+
"params": self.params,
|
59
|
+
}
|
60
|
+
data = f"data: {json.dumps(event_dict, ensure_ascii=False, separators=(',', ':'))}\n\n"
|
61
|
+
return data.encode(charset)
|
62
|
+
|
46
63
|
|
47
64
|
class Callback:
|
48
65
|
"""The original callback handler class. Outputs a stream of structured text
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
|
2
|
+
import asyncio
|
3
|
+
from typing import Any, Iterable, Protocol
|
4
|
+
from graphlib import TopologicalSorter, CycleError
|
3
5
|
from graphai.callback import Callback
|
4
6
|
from graphai.utils import logger
|
5
7
|
|
@@ -165,14 +167,14 @@ class Graph:
|
|
165
167
|
self.end_node = node
|
166
168
|
return self
|
167
169
|
|
168
|
-
def compile(self) ->
|
170
|
+
def compile(self, *, strict: bool = False) -> Graph:
|
169
171
|
"""
|
170
172
|
Validate the graph:
|
171
173
|
- exactly one start node present (or Graph.start_node set)
|
172
174
|
- at least one end node present
|
173
175
|
- all edges reference known nodes
|
174
|
-
- no cycles
|
175
176
|
- all nodes reachable from the start
|
177
|
+
- (optional) **no cycles** when strict=True
|
176
178
|
Returns self on success; raises GraphCompileError otherwise.
|
177
179
|
"""
|
178
180
|
# nodes map
|
@@ -301,6 +303,17 @@ class Graph:
|
|
301
303
|
if unreachable:
|
302
304
|
raise GraphCompileError(f"Unreachable nodes: {unreachable}")
|
303
305
|
|
306
|
+
# optional cycle detection (strict mode)
|
307
|
+
if strict:
|
308
|
+
preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
|
309
|
+
for s, ds in adj.items():
|
310
|
+
for d in ds:
|
311
|
+
preds[d].add(s)
|
312
|
+
try:
|
313
|
+
list(TopologicalSorter(preds).static_order())
|
314
|
+
except CycleError as e:
|
315
|
+
raise GraphCompileError("cycle detected in graph (strict mode)") from e
|
316
|
+
|
304
317
|
return self
|
305
318
|
|
306
319
|
def _validate_output(self, output: dict[str, Any], node_name: str):
|
@@ -362,6 +375,32 @@ class Graph:
|
|
362
375
|
del state["callback"]
|
363
376
|
return state
|
364
377
|
|
378
|
+
async def execute_many(
|
379
|
+
self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
|
380
|
+
) -> list[Any]:
|
381
|
+
"""Execute the graph on many inputs concurrently.
|
382
|
+
|
383
|
+
:param inputs: An iterable of input dicts to feed into the graph.
|
384
|
+
:type inputs: Iterable[dict]
|
385
|
+
:param concurrency: Maximum number of graph executions to run at once.
|
386
|
+
:type concurrency: int
|
387
|
+
:param state: Optional shared state to pass to each execution.
|
388
|
+
If you want isolated state per execution, pass None
|
389
|
+
and the graph's normal semantics will apply.
|
390
|
+
:type state: Optional[Any]
|
391
|
+
:returns: The list of results in the same order as ``inputs``.
|
392
|
+
:rtype: list[Any]
|
393
|
+
"""
|
394
|
+
|
395
|
+
sem = asyncio.Semaphore(concurrency)
|
396
|
+
|
397
|
+
async def _run_one(inp: dict[str, Any]) -> Any:
|
398
|
+
async with sem:
|
399
|
+
return await self.execute(input=inp)
|
400
|
+
|
401
|
+
tasks = [asyncio.create_task(_run_one(i)) for i in inputs]
|
402
|
+
return await asyncio.gather(*tasks)
|
403
|
+
|
365
404
|
def get_callback(self):
|
366
405
|
"""Get a new instance of the callback class.
|
367
406
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|