graphai-lib 0.0.9rc2__py3-none-any.whl → 0.0.10rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graphai/callback.py CHANGED
@@ -1,13 +1,70 @@
1
1
  import asyncio
2
+ from dataclasses import dataclass
3
+ from enum import Enum
4
+ import json
2
5
  from pydantic import Field
3
6
  from typing import Any
4
7
  from collections.abc import AsyncIterator
8
+ import warnings
5
9
 
6
10
 
7
11
  log_stream = True
8
12
 
9
13
 
14
class StrEnum(str, Enum):
    """Enum whose members are real ``str`` instances.

    Mixing in ``str`` lets members compare equal to their plain string values
    and be serialized directly by :mod:`json`, while ``__str__`` returns the
    raw value instead of ``ClassName.MEMBER``. Defined locally because
    :class:`enum.StrEnum` only exists from Python 3.11 and this package
    supports Python 3.10.
    """

    def __str__(self) -> str:
        return str(self.value)
17
+
18
class GraphEventType(StrEnum):
    """The kinds of ``GraphEvent`` that can be emitted.

    ``END`` is queued when a callback stream is closed, ``START_NODE`` and
    ``END_NODE`` bracket the output of an active node, and ``CALLBACK`` carries
    a token emitted through the callback. ``START`` presumably marks the
    beginning of a run (it is not emitted by the code visible here).
    """

    # whole-stream lifecycle
    START = "start"
    END = "end"
    # per-node lifecycle
    START_NODE = "start_node"
    END_NODE = "end_node"
    # tokens/user events sent through the callback
    CALLBACK = "callback"
24
+
25
@dataclass
class GraphEvent:
    """A single event emitted while a graph runs.

    Events are produced for lifecycle transitions (``start``, ``end``,
    ``start_node``, ``end_node``) and for user-defined output sent through a
    callback (``callback``).

    :param type: The kind of event.
    :type type: GraphEventType
    :param identifier: Identifier of the emitter. It is typically set by a
        callback handler and can be used to tell streams apart, for example by
        using a conversation or session ID.
    :type identifier: str
    :param token: Payload of the event, such as streamed LLM output or, for
        node events, the node name.
    :type token: str | None
    :param params: Extra data attached to the event, such as tool-call
        parameters or event metadata.
    :type params: dict[str, Any] | None
    """

    type: GraphEventType
    identifier: str
    token: str | None = None
    params: dict[str, Any] | None = None

    def encode(self, charset: str = "utf-8") -> bytes:
        """Serialize the event as a ``data: <json>\\n\\n`` frame.

        The event is rendered as compact JSON and wrapped in the Server-Sent
        Events ``data:`` framing, for compatibility with FastAPI and Starlette
        streaming responses.

        :param charset: The character set used to encode the frame.
        :type charset: str
        :returns: The encoded frame.
        :rtype: bytes
        :raises TypeError: If ``params`` contains values that ``json`` cannot
            serialize.
        """
        # Accept plain strings as well as enum members for ``type``.
        event_type = (
            self.type.value if hasattr(self.type, "value") else str(self.type)
        )
        payload = {
            "type": event_type,
            "identifier": self.identifier,
            "token": self.token,
            "params": self.params,
        }
        body = json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
        return f"data: {body}\n\n".encode(charset)
62
+
63
+
10
64
  class Callback:
65
+ """The original callback handler class. Outputs a stream of structured text
66
+ tokens. It is recommended to use the newer `EventCallback` handler instead.
67
+ """
11
68
  identifier: str = Field(
12
69
  default="graphai",
13
70
  description=(
@@ -67,6 +124,11 @@ class Callback:
67
124
  special_token_format: str = "<{identifier}:{token}:{params}>",
68
125
  token_format: str = "{token}",
69
126
  ):
127
+ warnings.warn(
128
+ "The `Callback` class is deprecated and will be removed in " +
129
+ "v0.1.0. Use the `EventCallback` class instead.",
130
+ DeprecationWarning
131
+ )
70
132
  self.identifier = identifier
71
133
  self.special_token_format = special_token_format
72
134
  self.token_format = token_format
@@ -198,3 +260,100 @@ class Callback:
198
260
  return self.special_token_format.format(
199
261
  identifier=identifier, token=name, params=params_str
200
262
  )
263
+
264
+
265
class EventCallback(Callback):
    """Callback handler that streams structured ``GraphEvent`` objects.

    The legacy ``Callback`` renders tokens into formatted strings. This
    handler instead puts ``GraphEvent`` instances on its queue:
    ``CALLBACK`` events for tokens, ``START_NODE``/``END_NODE`` events around
    active nodes, and a final ``END`` event on :meth:`close`. Consume them
    with :meth:`aiter`.

    :param identifier: Identifier stamped on every emitted event.
    :type identifier: str
    :param special_token_format: Deprecated; passed through to
        ``Callback.__init__``.
    :type special_token_format: str | None
    :param token_format: Deprecated; passed through to ``Callback.__init__``.
    :type token_format: str | None
    """

    def __init__(
        self,
        identifier: str = "graphai",
        special_token_format: str | None = None,
        token_format: str | None = None,
    ):
        # Only warn about the legacy formatting parameters when the caller
        # actually supplied them.
        if special_token_format is not None or token_format is not None:
            warnings.warn(
                "The `special_token_format` and `token_format` parameters are "
                + "deprecated and will be removed in v0.1.0.",
                DeprecationWarning,
                stacklevel=2,
            )
        if special_token_format is None:
            special_token_format = "<{identifier}:{token}:{params}>"
        if token_format is None:
            token_format = "{token}"
        # `Callback.__init__` warns that `Callback` is deprecated in favour of
        # `EventCallback`. That advice is wrong for this subclass, so silence it.
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="The `Callback` class is deprecated",
                category=DeprecationWarning,
            )
            super().__init__(identifier, special_token_format, token_format)
        self.events: list[GraphEvent] = []

    def _put_callback_event(self, token: str, node_name: str | None) -> None:
        """Queue a ``CALLBACK`` event carrying ``token``.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot add tokens to a closed stream")
        self._check_node_name(node_name=node_name)
        event = GraphEvent(
            type=GraphEventType.CALLBACK,
            identifier=self.identifier,
            token=token,
            params=None,
        )
        self.queue.put_nowait(event)

    def __call__(self, token: str, node_name: str | None = None):
        """Emit ``token`` as a ``CALLBACK`` event (synchronous variant).

        :raises RuntimeError: If the stream has already been closed.
        """
        self._put_callback_event(token, node_name)

    async def acall(self, token: str, node_name: str | None = None):
        """Emit ``token`` as a ``CALLBACK`` event (async variant).

        :raises RuntimeError: If the stream has already been closed.
        """
        # TODO JB: do we need to have `node_name` param?
        self._put_callback_event(token, node_name)

    async def aiter(self) -> AsyncIterator[GraphEvent]:  # type: ignore[override]
        """Yield queued events until the ``END`` event has been yielded.

        Iteration also stops when the stream is already done and the queue
        is empty. The stream is marked done once iteration finishes normally.
        """
        while True:
            try:
                if self._done and self.queue.empty():
                    break
                event = await self.queue.get()
                try:
                    yield event
                finally:
                    # Mark the item processed even if the consumer stops
                    # iterating at this event, so the queue's unfinished
                    # count does not leak.
                    self.queue.task_done()
                if event.type == GraphEventType.END:
                    break
            except asyncio.CancelledError:
                # NOTE: cancellation ends iteration normally; it is not
                # re-raised to the consumer.
                break
        self._done = True

    async def start_node(self, node_name: str, active: bool = True):
        """Make ``node_name`` the current node.

        When ``active`` is true, also emit a ``START_NODE`` event whose token
        is the node name.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot start node on a closed stream")
        self.current_node_name = node_name
        # NOTE(review): `first_token` is only cleared here and never read in
        # this class; presumably it is consumed by `Callback`. Confirm.
        if self.first_token:
            self.first_token = False
        self.active = active
        if self.active:
            event = GraphEvent(
                type=GraphEventType.START_NODE,
                identifier=self.identifier,
                token=self.current_node_name,
                params=None,
            )
            self.queue.put_nowait(event)

    async def end_node(self, node_name: str):
        """Emit an ``END_NODE`` event for the current node, if it is active.

        ``node_name`` is not used. The event carries ``self.current_node_name``
        as set by :meth:`start_node`; the parameter is presumably kept to
        mirror the parent signature.

        :raises RuntimeError: If the stream has already been closed.
        """
        if self._done:
            raise RuntimeError("Cannot end node on a closed stream")
        if self.active:
            event = GraphEvent(
                type=GraphEventType.END_NODE,
                identifier=self.identifier,
                token=self.current_node_name,
                params=None,
            )
            self.queue.put_nowait(event)

    async def close(self):
        """Mark the stream done and queue the final ``END`` event.

        Calling this again after the stream is done returns without emitting
        anything.
        """
        if self._done:
            return
        end_event = GraphEvent(type=GraphEventType.END, identifier=self.identifier)
        # Set done before queueing END so no further tokens can be added.
        self._done = True
        self.queue.put_nowait(end_event)
        # Deliberately no `queue.join()`: it can deadlock. The stream finishes
        # when `aiter` yields the END event.

    async def _build_special_token(
        self, name: str, params: dict[str, Any] | None = None
    ):
        """Not supported: this handler emits ``GraphEvent`` objects instead.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("This method is not implemented for the `EventCallback` class.")
359
+
graphai/graph.py CHANGED
@@ -1,5 +1,7 @@
1
1
  from __future__ import annotations
2
- from typing import Any, Protocol
2
+ import asyncio
3
+ from typing import Any, Iterable, Protocol
4
+ from graphlib import TopologicalSorter, CycleError
3
5
  from graphai.callback import Callback
4
6
  from graphai.utils import logger
5
7
 
@@ -165,14 +167,14 @@ class Graph:
165
167
  self.end_node = node
166
168
  return self
167
169
 
168
- def compile(self) -> "Graph":
170
+ def compile(self, *, strict: bool = False) -> Graph:
169
171
  """
170
172
  Validate the graph:
171
173
  - exactly one start node present (or Graph.start_node set)
172
174
  - at least one end node present
173
175
  - all edges reference known nodes
174
- - no cycles
175
176
  - all nodes reachable from the start
177
+ (optional) **no cycles** when strict=True
176
178
  Returns self on success; raises GraphCompileError otherwise.
177
179
  """
178
180
  # nodes map
@@ -301,6 +303,17 @@ class Graph:
301
303
  if unreachable:
302
304
  raise GraphCompileError(f"Unreachable nodes: {unreachable}")
303
305
 
306
+ # optional cycle detection (strict mode)
307
+ if strict:
308
+ preds: dict[str, set[str]] = {n: set() for n in nodes.keys()}
309
+ for s, ds in adj.items():
310
+ for d in ds:
311
+ preds[d].add(s)
312
+ try:
313
+ list(TopologicalSorter(preds).static_order())
314
+ except CycleError as e:
315
+ raise GraphCompileError("cycle detected in graph (strict mode)") from e
316
+
304
317
  return self
305
318
 
306
319
  def _validate_output(self, output: dict[str, Any], node_name: str):
@@ -362,6 +375,32 @@ class Graph:
362
375
  del state["callback"]
363
376
  return state
364
377
 
378
async def execute_many(
    self, inputs: Iterable[dict[str, Any]], *, concurrency: int = 5
) -> list[Any]:
    """Execute the graph on many inputs concurrently.

    Each input dict is passed to ``self.execute`` as ``input``. At most
    ``concurrency`` executions are in flight at any time.

    :param inputs: An iterable of input dicts to feed into the graph. It is
        consumed eagerly, creating one task per item.
    :type inputs: Iterable[dict]
    :param concurrency: Maximum number of graph executions to run at once.
        Must be at least 1.
    :type concurrency: int
    :returns: The list of results in the same order as ``inputs``.
    :rtype: list[Any]
    :raises ValueError: If ``concurrency`` is less than 1.
    :raises Exception: The first exception raised by any execution is
        propagated, and the remaining executions are cancelled.
    """
    if concurrency < 1:
        # Semaphore(0) would block every execution forever.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    sem = asyncio.Semaphore(concurrency)

    async def _run_one(inp: dict[str, Any]) -> Any:
        async with sem:
            return await self.execute(input=inp)

    tasks = [asyncio.create_task(_run_one(inp)) for inp in inputs]
    try:
        return await asyncio.gather(*tasks)
    except BaseException:
        # gather() does not cancel sibling tasks when one of them fails, so
        # cancel them explicitly rather than leaving orphaned executions.
        for task in tasks:
            task.cancel()
        raise
403
+
365
404
  def get_callback(self):
366
405
  """Get a new instance of the callback class.
367
406
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphai-lib
3
- Version: 0.0.9rc2
3
+ Version: 0.0.10rc1
4
4
  Summary: Not an AI framework
5
5
  Requires-Python: <3.14,>=3.10
6
6
  Description-Content-Type: text/markdown
@@ -0,0 +1,11 @@
1
+ graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
2
+ graphai/callback.py,sha256=xQhK1xbjkouemaIB5hjPZpo3thxiA6VVu9TpHslOsww,13851
3
+ graphai/graph.py,sha256=CSdRYlQnfVHNw4XyKMEZ1rJqqg6rOx-aT8w9xumFhBs,20138
4
+ graphai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ graphai/utils.py,sha256=LIFg-fQalU9sB5DCuk6is48OdpEgNX95i9h-YddFbvM,11717
6
+ graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
7
+ graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
8
+ graphai_lib-0.0.10rc1.dist-info/METADATA,sha256=2sIQkGfW7m0UoOEioVono53cVqpDkxRlxoDVlt2Tl1o,914
9
+ graphai_lib-0.0.10rc1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
+ graphai_lib-0.0.10rc1.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
11
+ graphai_lib-0.0.10rc1.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- graphai/__init__.py,sha256=UbqXq7iGIYe1GyTPcpgLSXbgWovggFsAbTMtr4JQm3M,160
2
- graphai/callback.py,sha256=Wl0JCmE8NcFifKmP9-a5bFa0WKVHTdrSClHVRmIEPpc,7323
3
- graphai/graph.py,sha256=O0hZ_29ln8oTYnk5EX5bdhO6-Cr9pm-6IVCq-ZPWbgY,18526
4
- graphai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- graphai/utils.py,sha256=LIFg-fQalU9sB5DCuk6is48OdpEgNX95i9h-YddFbvM,11717
6
- graphai/nodes/__init__.py,sha256=IaMUryAqTZlcEqh-ZS6A4NIYG18JZwzo145dzxsYjAk,74
7
- graphai/nodes/base.py,sha256=SoKfOdRu5EIJ_z8xIz5zbNXcxPI2l9MKTQDeaQI-2no,7494
8
- graphai_lib-0.0.9rc2.dist-info/METADATA,sha256=_3lwoAyLzSSgFq8GQHW_FAGd6A0kKiGQB3XWCnirJRs,913
9
- graphai_lib-0.0.9rc2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
10
- graphai_lib-0.0.9rc2.dist-info/top_level.txt,sha256=TXlqmhLViX-3xGH2g5w6cavRd-QMf229Hl88jdMOGt8,8
11
- graphai_lib-0.0.9rc2.dist-info/RECORD,,