graphai-lib 0.0.9rc2__tar.gz → 0.0.10rc1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/PKG-INFO +1 -1
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/callback.py +159 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/graph.py +42 -3
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/PKG-INFO +1 -1
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/pyproject.toml +1 -1
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/README.md +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/__init__.py +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/nodes/__init__.py +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/nodes/base.py +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/py.typed +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai/utils.py +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/SOURCES.txt +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/dependency_links.txt +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/requires.txt +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/graphai_lib.egg-info/top_level.txt +0 -0
- {graphai_lib-0.0.9rc2 → graphai_lib-0.0.10rc1}/setup.cfg +0 -0
@@ -1,13 +1,70 @@
|
|
1
1
|
import asyncio
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from enum import Enum
|
4
|
+
import json
|
2
5
|
from pydantic import Field
|
3
6
|
from typing import Any
|
4
7
|
from collections.abc import AsyncIterator
|
8
|
+
import warnings
|
5
9
|
|
6
10
|
|
7
11
|
log_stream = True  # NOTE(review): module-level flag; no reader of it is visible in this chunk — confirm its purpose before relying on it.
|
8
12
|
|
9
13
|
|
14
|
+
class StrEnum(Enum):
|
15
|
+
def __str__(self) -> str:
|
16
|
+
return str(self.value)
|
17
|
+
|
18
|
+
class GraphEventType(StrEnum):
|
19
|
+
START = "start"
|
20
|
+
END = "end"
|
21
|
+
START_NODE = "start_node"
|
22
|
+
END_NODE = "end_node"
|
23
|
+
CALLBACK = "callback"
|
24
|
+
|
25
|
+
@dataclass
|
26
|
+
class GraphEvent:
|
27
|
+
"""A graph event emitted for specific graph events such as start node or end node,
|
28
|
+
and used by the callback to emit user-defined events.
|
29
|
+
|
30
|
+
:param type: The type of event, can be start_node, end_node, or callback.
|
31
|
+
:type type: GraphEventType
|
32
|
+
:param identifier: The identifier of the event, this is set typically by a callback
|
33
|
+
handler and can be used to distinguish between different events. For example, a
|
34
|
+
conversation/session ID could be used.
|
35
|
+
:type identifier: str
|
36
|
+
:param token: The token associated with the event, such as LLM streamed output.
|
37
|
+
:type token: str | None
|
38
|
+
:param params: The parameters associated with the event, such as tool call parameters
|
39
|
+
or event metadata.
|
40
|
+
:type params: dict[str, Any] | None
|
41
|
+
"""
|
42
|
+
type: GraphEventType
|
43
|
+
identifier: str
|
44
|
+
token: str | None = None
|
45
|
+
params: dict[str, Any] | None = None
|
46
|
+
|
47
|
+
def encode(self, charset: str = "utf-8") -> bytes:
|
48
|
+
"""Encodes the event as a JSON string, important for compatability with FastAPI
|
49
|
+
and starlette.
|
50
|
+
|
51
|
+
:param charset: The character set to use for encoding the event.
|
52
|
+
:type charset: str
|
53
|
+
"""
|
54
|
+
event_dict = {
|
55
|
+
"type": self.type.value if hasattr(self.type, "value") else str(self.type),
|
56
|
+
"identifier": self.identifier,
|
57
|
+
"token": self.token,
|
58
|
+
"params": self.params,
|
59
|
+
}
|
60
|
+
data = f"data: {json.dumps(event_dict, ensure_ascii=False, separators=(',', ':'))}\n\n"
|
61
|
+
return data.encode(charset)
|
62
|
+
|
63
|
+
|
10
64
|
class Callback:
|
65
|
+
"""The original callback handler class. Outputs a stream of structured text
|
66
|
+
tokens. It is recommended to use the newer `EventCallback` handler instead.
|
67
|
+
"""
|
11
68
|
identifier: str = Field(
|
12
69
|
default="graphai",
|
13
70
|
description=(
|
@@ -67,6 +124,11 @@ class Callback:
|
|
67
124
|
special_token_format: str = "<{identifier}:{token}:{params}>",
|
68
125
|
token_format: str = "{token}",
|
69
126
|
):
|
127
|
+
warnings.warn(
|
128
|
+
"The `Callback` class is deprecated and will be removed in " +
|
129
|
+
"v0.1.0. Use the `EventCallback` class instead.",
|
130
|
+
DeprecationWarning
|
131
|
+
)
|
70
132
|
self.identifier = identifier
|
71
133
|
self.special_token_format = special_token_format
|
72
134
|
self.token_format = token_format
|
@@ -198,3 +260,100 @@ class Callback:
|
|
198
260
|
return self.special_token_format.format(
|
199
261
|
identifier=identifier, token=name, params=params_str
|
200
262
|
)
|
263
|
+
|
264
|
+
|
265
|
+
class EventCallback(Callback):
    """Callback handler that streams structured `GraphEvent` objects.

    Unlike the legacy `Callback`, which formats tokens into text, this handler puts a
    `GraphEvent` on its inherited ``queue`` for every streamed token (``CALLBACK``),
    every active node start/end (``START_NODE`` / ``END_NODE``) and the end of the
    stream (``END``). Consumers read the events with `aiter()`.
    """

    def __init__(
        self,
        identifier: str = "graphai",
        special_token_format: str | None = None,
        token_format: str | None = None,
    ):
        """Create the handler.

        :param identifier: Identifier stamped onto every emitted `GraphEvent`.
        :type identifier: str
        :param special_token_format: Deprecated; forwarded to `Callback`. Passing a
            value emits a DeprecationWarning.
        :type special_token_format: str | None
        :param token_format: Deprecated; forwarded to `Callback`. Passing a value
            emits a DeprecationWarning.
        :type token_format: str | None
        """
        # Only warn when the caller actually passes a deprecated parameter; both
        # default to None, so warning unconditionally flagged every construction.
        if special_token_format is not None or token_format is not None:
            warnings.warn(
                "The `special_token_format` and `token_format` parameters are "
                "deprecated and will be removed in v0.1.0.",
                DeprecationWarning,
                stacklevel=2,
            )
        if special_token_format is None:
            special_token_format = "<{identifier}:{token}:{params}>"
        if token_format is None:
            token_format = "{token}"
        # `Callback.__init__` warns that `Callback` is deprecated in favour of
        # `EventCallback`, which is misleading when constructing an EventCallback,
        # so silence exactly that message for the super() call.
        # NOTE: warnings.catch_warnings mutates global filter state (not thread-safe).
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The `Callback` class is deprecated",
                category=DeprecationWarning,
            )
            super().__init__(identifier, special_token_format, token_format)
        self.events: list[GraphEvent] = []

    def _emit_token(self, token: str, node_name: str | None) -> None:
        """Validate stream state and enqueue a ``CALLBACK`` event carrying ``token``.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot add tokens to a closed stream")
        self._check_node_name(node_name=node_name)
        # otherwise we just assume node is correct and send token
        event = GraphEvent(
            type=GraphEventType.CALLBACK,
            identifier=self.identifier,
            token=token,
            params=None,
        )
        self.queue.put_nowait(event)

    def __call__(self, token: str, node_name: str | None = None):
        """Synchronously emit ``token`` as a ``CALLBACK`` event."""
        self._emit_token(token, node_name)

    async def acall(self, token: str, node_name: str | None = None):
        """Asynchronously emit ``token`` as a ``CALLBACK`` event."""
        # TODO JB: do we need to have `node_name` param?
        self._emit_token(token, node_name)

    async def aiter(self) -> AsyncIterator[GraphEvent]:  # type: ignore[override]
        """Used by receiver to get the events from the stream queue. Yields events
        until an ``END`` event is received, or until the stream is closed and the
        queue has been drained. Stops quietly if the consuming task is cancelled.
        """
        while True:  # Keep going until we see the END event
            try:
                if self._done and self.queue.empty():
                    break
                token = await self.queue.get()
                yield token
                self.queue.task_done()
                if token.type == GraphEventType.END:
                    break
            except asyncio.CancelledError:
                break
        self._done = True  # Mark as done after processing all tokens

    async def start_node(self, node_name: str, active: bool = True):
        """Make ``node_name`` the current node and, if ``active``, emit a
        ``START_NODE`` event for it.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot start node on a closed stream")
        self.current_node_name = node_name
        if self.first_token:
            self.first_token = False
        self.active = active
        if self.active:
            token = GraphEvent(
                type=GraphEventType.START_NODE,
                identifier=self.identifier,
                token=self.current_node_name,
                params=None,
            )
            self.queue.put_nowait(token)

    async def end_node(self, node_name: str):
        """Emit an ``END_NODE`` event for the current node if it is active.

        ``node_name`` is currently ignored: the event carries
        ``self.current_node_name`` as recorded by `start_node`.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot end node on a closed stream")
        if self.active:
            token = GraphEvent(
                type=GraphEventType.END_NODE,
                identifier=self.identifier,
                token=self.current_node_name,
                params=None,
            )
            self.queue.put_nowait(token)

    async def close(self):
        """Close the stream and prevent further tokens from being added.
        This will send an END event and set the done flag to True. Calling it on an
        already-closed stream is a no-op.
        """
        if self._done:
            return
        end_token = GraphEvent(type=GraphEventType.END, identifier=self.identifier)
        self._done = True  # Set done before putting the end token
        self.queue.put_nowait(end_token)
        # Don't wait for queue.join() as it can cause deadlock
        # The stream will close when aiter processes the END token

    async def _build_special_token(
        self, name: str, params: dict[str, Any] | None = None
    ):
        """Not supported: EventCallback emits `GraphEvent` objects, not text tokens."""
        raise NotImplementedError("This method is not implemented for the `EventCallback` class.")
|
359
|
+
|
@@ -1,5 +1,7 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
|
2
|
+
import asyncio
|
3
|
+
from typing import Any, Iterable, Protocol
|
4
|
+
from graphlib import TopologicalSorter, CycleError
|
3
5
|
from graphai.callback import Callback
|
4
6
|
from graphai.utils import logger
|
5
7
|
|
@@ -165,14 +167,14 @@ class Graph:
|
|
165
167
|
self.end_node = node
|
166
168
|
return self
|
167
169
|
|
168
|
-
def compile(self) ->
|
170
|
+
def compile(self, *, strict: bool = False) -> Graph:
|
169
171
|
"""
|
170
172
|
Validate the graph:
|
171
173
|
- exactly one start node present (or Graph.start_node set)
|
172
174
|
- at least one end node present
|
173
175
|
- all edges reference known nodes
|
174
|
-
- no cycles
|
175
176
|
- all nodes reachable from the start
|
177
|
+
(optional) **no cycles** when strict=True
|
176
178
|
Returns self on success; raises GraphCompileError otherwise.
|
177
179
|
"""
|
178
180
|
# nodes map
|
@@ -301,6 +303,17 @@ class Graph:
|
|
301
303
|
if unreachable:
|
302
304
|
raise GraphCompileError(f"Unreachable nodes: {unreachable}")
|
303
305
|
|
306
|
+
# optional cycle detection (strict mode)
|
307
|
+
if strict:
|
308
|
+
preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
|
309
|
+
for s, ds in adj.items():
|
310
|
+
for d in ds:
|
311
|
+
preds[d].add(s)
|
312
|
+
try:
|
313
|
+
list(TopologicalSorter(preds).static_order())
|
314
|
+
except CycleError as e:
|
315
|
+
raise GraphCompileError("cycle detected in graph (strict mode)") from e
|
316
|
+
|
304
317
|
return self
|
305
318
|
|
306
319
|
def _validate_output(self, output: dict[str, Any], node_name: str):
|
@@ -362,6 +375,32 @@ class Graph:
|
|
362
375
|
del state["callback"]
|
363
376
|
return state
|
364
377
|
|
378
|
+
async def execute_many(
    self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
) -> list[Any]:
    """Execute the graph on many inputs concurrently.

    Each input is passed to ``self.execute(input=...)``; at most ``concurrency``
    executions are in flight at any moment.

    :param inputs: An iterable of input dicts to feed into the graph.
    :type inputs: Iterable[dict]
    :param concurrency: Maximum number of graph executions to run at once; must be
        at least 1.
    :type concurrency: int
    :returns: The list of results in the same order as ``inputs``.
    :rtype: list[Any]
    :raises ValueError: If ``concurrency`` is less than 1.
    :raises Exception: The first exception raised by any execution. All other
        executions are cancelled and awaited before it propagates.
    """
    if concurrency < 1:
        # Semaphore(0) would block every execution forever; negatives are invalid.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    sem = asyncio.Semaphore(concurrency)

    async def _run_one(inp: dict[str, Any]) -> Any:
        async with sem:
            return await self.execute(input=inp)

    tasks = [asyncio.create_task(_run_one(i)) for i in inputs]
    try:
        return await asyncio.gather(*tasks)
    except Exception:
        # gather() does not cancel sibling tasks when one of them fails; cancel
        # them and wait for them to finish so no execution outlives this call.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
|
403
|
+
|
365
404
|
def get_callback(self):
|
366
405
|
"""Get a new instance of the callback class.
|
367
406
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|