langgraph-celery 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,5 @@
1
+ from langgraph_celery.events import GraphResult, NodeEvent
2
+ from langgraph_celery.interrupt import resume
3
+ from langgraph_celery.task import task
4
+
5
+ __all__ = ["task", "resume", "GraphResult", "NodeEvent"]
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import AsyncIterator
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from typing import Any
7
+
8
+
9
+ async def invoke(graph: Any, input: dict, config: dict | None = None) -> dict:
10
+ return await graph.ainvoke(input, config=config or {})
11
+
12
+
13
async def stream_events(
    graph: Any,
    input: dict,
    config: dict | None = None,
) -> AsyncIterator[tuple[str, str, dict]]:
    """Yield ``(event_kind, node_name, payload)`` triples from the graph's v2 event stream.

    Missing fields default to an empty string (or empty dict for the payload).
    """
    stream = graph.astream_events(input, config=config or {}, version="v2")
    async for raw_event in stream:
        yield (
            raw_event.get("event", ""),
            raw_event.get("name", ""),
            raw_event.get("data", {}),
        )
23
+
24
+
25
def run_sync(coro) -> Any:
    """Execute *coro* to completion from synchronous code and return its result.

    Uses ``asyncio.run`` directly when no event loop is active in this
    thread; otherwise the coroutine is run on a fresh loop in a worker
    thread so we never re-enter the already-running loop.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: safe to create one right here.
        return asyncio.run(coro)
    # A loop is already running; delegate to a single-worker thread.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
def build_checkpointer(checkpointer: str) -> Any:
    """Resolve a checkpointer name to a LangGraph checkpointer.

    Supported names:
      * ``"memory"``   -> a new ``MemorySaver`` instance.
      * ``"postgres"`` -> the ``AsyncPostgresSaver`` *class* itself.
        NOTE(review): returned uninstantiated, unlike ``memory`` —
        presumably the caller supplies the connection string; confirm
        this asymmetry is intended.

    Raises:
        ValueError: for any other name.
    """
    if checkpointer == "memory":
        from langgraph.checkpoint.memory import MemorySaver

        return MemorySaver()
    if checkpointer == "postgres":
        from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

        return AsyncPostgresSaver
    raise ValueError(
        f"Unknown checkpointer: {checkpointer!r}. Use 'memory' or 'postgres'."
    )
14
+
15
+
16
def resolve_thread_id(thread_id_from: str, kwargs: dict) -> str:
    """Derive a checkpoint thread id from the task kwarg named *thread_id_from*.

    The id is namespaced as ``task:<value>``.

    Raises:
        KeyError: when the kwarg is absent or ``None``.
    """
    value = kwargs.get(thread_id_from)
    if value is None:
        available = list(kwargs)
        raise KeyError(
            f"thread_id_from={thread_id_from!r} not found in task kwargs: {available}"
        )
    return f"task:{value}"
23
+
24
+
25
def inject_thread_id(config: dict, thread_id: str) -> dict:
    """Return a copy of *config* with ``configurable.thread_id`` set.

    Neither *config* nor its ``configurable`` sub-dict is mutated.
    """
    updated = {**config}
    updated["configurable"] = {
        **updated.get("configurable", {}),
        "thread_id": thread_id,
    }
    return updated
@@ -0,0 +1,18 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
+ from langgraph_celery.bridge import stream_events
7
+ from langgraph_celery.events import GraphResult, NodeEvent
8
+
9
+
10
class NodeEventEmitter:
    """Runs a graph and replays its recorded node events through a callback.

    NOTE(review): events are delivered only after the graph has finished
    (``GraphResult.from_stream`` drains the stream first), not live while
    the graph runs — confirm this post-hoc delivery is intended.
    """

    def __init__(self, callback: Callable[[NodeEvent], None]) -> None:
        # Invoked once per collected NodeEvent after the run completes.
        self._callback = callback

    async def run(self, graph: Any, input: dict, config: dict | None = None) -> GraphResult:
        """Execute *graph*, fire the callback for each node event, return the result."""
        outcome = await GraphResult.from_stream(stream_events(graph, input, config))
        for node_event in outcome.node_events:
            self._callback(node_event)
        return outcome
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import AsyncIterator
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+
8
@dataclass
class NodeEvent:
    """A single LangGraph stream event attributed to a node (e.g. a tool)."""

    # Raw event kind, e.g. "on_tool_start" / "on_tool_end".
    kind: str
    # Name of the node/tool the event came from.
    node: str
    # Event payload exactly as provided by the stream.
    data: dict[str, Any] = field(default_factory=dict)


@dataclass
class GraphResult:
    """Aggregated outcome of a LangGraph run.

    Attributes:
        output: final graph state (typically containing a ``messages`` list).
        full_answer: text of the final AI answer, or joined streamed tokens.
        tool_calls: tool invocations observed during the run.
        node_events: tool start/end stream events in arrival order.
        interrupted: True when the run stopped at an interrupt.
    """

    output: dict[str, Any]
    full_answer: str = ""
    tool_calls: list[dict] = field(default_factory=list)
    node_events: list[NodeEvent] = field(default_factory=list)
    interrupted: bool = False

    @classmethod
    def from_output(cls, output: dict) -> GraphResult:
        """Build a result from a finished graph state.

        Scans ``output["messages"]`` newest-first: the newest AIMessage with
        plain content (and no tool calls) becomes ``full_answer``; tool calls
        found on AI messages newer than it are collected.
        """
        messages = output.get("messages", [])
        if not messages:
            # Short-circuit on an empty state; this also keeps the
            # langchain-core import out of paths that never need it.
            return cls(output=output)

        from langchain_core.messages import AIMessage

        full_answer = ""
        tool_calls: list[dict] = []
        for msg in reversed(messages):
            if not isinstance(msg, AIMessage):
                continue
            if msg.content and not getattr(msg, "tool_calls", None):
                # Newest final answer found; stop scanning.
                full_answer = msg.content
                break
            if getattr(msg, "tool_calls", None):
                tool_calls.extend(msg.tool_calls)

        return cls(output=output, full_answer=full_answer, tool_calls=tool_calls)

    @classmethod
    async def from_stream(
        cls,
        event_iter: AsyncIterator[tuple[str, str, dict]],
    ) -> GraphResult:
        """Consume a ``(kind, name, data)`` event stream into a GraphResult.

        Collects streamed answer tokens, tool start/end events, and the
        final graph output (from the top-level ``on_chain_end`` event).
        Streamed tokens are used as ``full_answer`` only when the final
        output yields none.
        """
        tokens: list[str] = []
        tool_calls: list[dict] = []
        node_events: list[NodeEvent] = []
        output: dict[str, Any] = {}

        async for kind, name, data in event_iter:
            if kind == "on_chat_model_stream":
                chunk = data.get("chunk")
                if chunk is not None:
                    tokens.extend(_chunk_texts(chunk))
            elif kind == "on_tool_start":
                node_events.append(NodeEvent(kind=kind, node=name, data=data))
            elif kind == "on_tool_end":
                node_events.append(NodeEvent(kind=kind, node=name, data=data))
                tool_calls.append(
                    {
                        "name": name,
                        "input": data.get("input") or {},
                        "output": data.get("output"),
                    }
                )
            elif kind == "on_chain_end" and name == "LangGraph":
                # The top-level chain-end event carries the final graph state.
                output = data.get("output", {})

        result = cls.from_output(output)
        if tokens and not result.full_answer:
            # No final message in output; fall back to the streamed tokens.
            result.full_answer = "".join(tokens)
        result.node_events = node_events
        if not result.tool_calls:
            result.tool_calls = tool_calls
        return result


def _chunk_texts(chunk: Any) -> list[str]:
    """Extract plain-text pieces from a chat-model stream chunk's content."""
    content = getattr(chunk, "content", None)
    if isinstance(content, str):
        return [content] if content else []
    if isinstance(content, list):
        # Content-block style (list of dicts): keep only "text" parts.
        return [
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
    return []
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from langgraph_celery.bridge import invoke, run_sync
6
+ from langgraph_celery.events import GraphResult
7
+
8
+
9
async def _run_with_interrupt(graph: Any, graph_input: Any, config: dict) -> GraphResult:
    """Invoke the graph; if it paused on an interrupt, mark the result.

    When the output carries ``__interrupt__`` entries, their payloads are
    surfaced under ``output["__interrupt_values__"]`` and the result is
    flagged ``interrupted=True``.
    """
    output = await invoke(graph, graph_input, config)
    pending = output.get("__interrupt__")
    if not pending:
        return GraphResult.from_output(output)
    output["__interrupt_values__"] = [item.value for item in pending]
    return GraphResult(output=output, interrupted=True)
21
+
22
+
23
async def _resume(graph: Any, resume_value: Any, config: dict) -> GraphResult:
    """Resume an interrupted graph run by feeding *resume_value* via a Command.

    Mirrors ``_run_with_interrupt``: a still-pending interrupt is surfaced
    under ``output["__interrupt_values__"]`` with ``interrupted=True``.
    """
    from langgraph.types import Command

    output = await invoke(graph, Command(resume=resume_value), config)
    pending = output.get("__interrupt__")
    if not pending:
        return GraphResult.from_output(output)
    output["__interrupt_values__"] = [item.value for item in pending]
    return GraphResult(output=output, interrupted=True)
34
+
35
+
36
def run_with_interrupt(graph: Any, graph_input: Any, config: dict) -> GraphResult:
    """Synchronous entry point for running a graph that may interrupt."""
    coro = _run_with_interrupt(graph, graph_input, config)
    return run_sync(coro)
38
+
39
+
40
def resume(graph: Any, resume_value: Any, config: dict) -> GraphResult:
    """Synchronously resume an interrupted graph run with *resume_value*."""
    coro = _resume(graph, resume_value, config)
    return run_sync(coro)
File without changes
File without changes
@@ -0,0 +1,67 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+ from langgraph_celery.bridge import stream_events
7
+ from langgraph_celery.events import GraphResult
8
+
9
+
10
class RedisStreamer:
    """Publishes graph stream events to a Redis pub/sub channel while running.

    Published messages are JSON objects: ``{"type": "token", "content": ...}``
    per answer token, ``{"type": "tool_start"/"tool_end", "name": ...}`` per
    tool event, and a final ``{"type": "done"}``.
    """

    def __init__(self, redis_url: str, channel: str) -> None:
        self._redis_url = redis_url
        # May be a str.format template filled from the task's kwargs,
        # e.g. "chat:{session_id}".
        self._channel = channel

    def _resolve_channel(self, kwargs: dict) -> str:
        """Expand the channel template with *kwargs*; fall back to the raw name.

        Catches IndexError too, so a positional placeholder like "{0}" falls
        back to the literal template instead of crashing the task.
        """
        try:
            return self._channel.format(**kwargs)
        except (KeyError, IndexError):
            return self._channel

    async def run(
        self,
        graph: Any,
        input: dict,
        config: dict | None = None,
        *,
        task_kwargs: dict | None = None,
    ) -> GraphResult:
        """Run the graph, publishing its events to Redis, and return the result."""
        import redis.asyncio as aioredis

        channel = self._resolve_channel(task_kwargs or {})
        client = aioredis.from_url(self._redis_url)
        try:
            result = await GraphResult.from_stream(
                self._publish_stream(client, channel, graph, input, config)
            )
        finally:
            # Always release the connection, even if the stream errors out.
            await client.aclose()
        return result

    @staticmethod
    def _token_texts(data: dict) -> list[str]:
        """Non-empty text pieces carried by an ``on_chat_model_stream`` payload."""
        chunk = data.get("chunk")
        if chunk is None:
            return []
        content = getattr(chunk, "content", None)
        if isinstance(content, str):
            return [content] if content else []
        if isinstance(content, list):
            # Content-block style: keep only non-empty "text" parts.
            texts = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text", "")
                    if text:
                        texts.append(text)
            return texts
        return []

    async def _publish_stream(
        self, client: Any, channel: str, graph: Any, input: dict, config: dict | None
    ):
        """Relay graph events to Redis, yielding every event through unchanged."""
        async for kind, name, data in stream_events(graph, input, config):
            if kind == "on_chat_model_stream":
                for text in self._token_texts(data):
                    msg = json.dumps({"type": "token", "content": text})
                    await client.publish(channel, msg)
            elif kind == "on_tool_start":
                await client.publish(channel, json.dumps({"type": "tool_start", "name": name}))
            elif kind == "on_tool_end":
                await client.publish(channel, json.dumps({"type": "tool_end", "name": name}))
            elif kind == "on_chain_end" and name == "LangGraph":
                await client.publish(channel, json.dumps({"type": "done"}))
            yield kind, name, data
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import inspect
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
8
+ from langgraph_celery.bridge import invoke, run_sync
9
+
10
+
11
def task(
    graph_factory: Callable[..., Any] | None = None,
    *,
    streaming: str | None = None,
    channel: str | None = None,
    checkpointer: str | None = None,
    thread_id_from: str | None = None,
    emit_node_events: bool = False,
    on_node_event: Callable | None = None,
    interrupt_before: list[str] | None = None,
):
    """Decorate a Celery task function so it runs a LangGraph graph.

    Args:
        graph_factory: callable that builds the graph; task kwargs whose
            names match the factory's parameters are forwarded to it.
        streaming: "redis" to publish events to Redis while running.
        channel: Redis channel (str.format template over task kwargs);
            required when ``streaming="redis"``.
        checkpointer: "memory" or "postgres" checkpointer to attach.
        thread_id_from: task kwarg whose value derives the checkpoint
            thread id (only used when a checkpointer is configured).
        emit_node_events: collect node events and deliver each one to
            *on_node_event* after the run.
        on_node_event: callback invoked per NodeEvent when
            ``emit_node_events`` is set.
        interrupt_before: when set, run via the interrupt-aware path.

    The decorated function's bound arguments become the graph input; the
    task's configuration is exposed on ``wrapper._langgraph_celery``.
    """

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            graph = _resolve_graph(graph_factory, fn, args, kwargs)
            graph_input = _build_input(fn, args, kwargs)
            config = _build_config(checkpointer, thread_id_from, kwargs, graph)

            # Dispatch on the configured execution mode; the default is a
            # plain synchronous invoke.
            if streaming == "redis":
                return run_sync(_run_redis(graph, graph_input, config, channel, kwargs))

            if emit_node_events:
                return run_sync(_run_emit(graph, graph_input, config, on_node_event))

            if interrupt_before:
                from langgraph_celery.interrupt import run_with_interrupt

                return run_with_interrupt(graph, graph_input, config)

            return run_sync(invoke(graph, graph_input, config))

        # Expose the full configuration for introspection and testing.
        wrapper._langgraph_celery = {
            "graph_factory": graph_factory,
            "streaming": streaming,
            "channel": channel,
            "checkpointer": checkpointer,
            "thread_id_from": thread_id_from,
            "emit_node_events": emit_node_events,
            "on_node_event": on_node_event,  # was missing from the metadata dict
            "interrupt_before": interrupt_before,
        }
        return wrapper

    # Bare-@task usage: the decorated function arrives as the first
    # positional argument. NOTE(review): this path clears graph_factory,
    # so the wrapper raises at call time unless a factory is supplied —
    # confirm bare usage is deliberately unsupported.
    if callable(graph_factory):
        fn = graph_factory
        graph_factory = None
        return decorator(fn)

    return decorator
59
+
60
+
61
+ def _resolve_graph(graph_factory, fn, args, kwargs):
62
+ if graph_factory is not None:
63
+ sig = inspect.signature(graph_factory)
64
+ params = list(sig.parameters.keys())
65
+ bound_kwargs = {k: v for k, v in kwargs.items() if k in params}
66
+ return graph_factory(**bound_kwargs)
67
+ raise ValueError("graph_factory must be provided via task(graph=...)")
68
+
69
+
70
+ def _build_input(fn, args, kwargs) -> dict:
71
+ sig = inspect.signature(fn)
72
+ bound = sig.bind(*args, **kwargs)
73
+ bound.apply_defaults()
74
+ return dict(bound.arguments)
75
+
76
+
77
+ def _build_config(checkpointer_name, thread_id_from, kwargs, graph) -> dict:
78
+ config: dict = {}
79
+
80
+ if checkpointer_name is not None:
81
+ from langgraph_celery.checkpointing import (
82
+ build_checkpointer,
83
+ inject_thread_id,
84
+ resolve_thread_id,
85
+ )
86
+
87
+ cp = build_checkpointer(checkpointer_name)
88
+ try:
89
+ graph.checkpointer = cp
90
+ except Exception:
91
+ pass
92
+
93
+ if thread_id_from:
94
+ thread_id = resolve_thread_id(thread_id_from, kwargs)
95
+ config = inject_thread_id(config, thread_id)
96
+
97
+ return config
98
+
99
+
100
+ async def _run_redis(graph, graph_input, config, channel, task_kwargs) -> Any:
101
+ from langgraph_celery.streaming.redis import RedisStreamer
102
+
103
+ if channel is None:
104
+ raise ValueError("channel must be set when streaming='redis'")
105
+
106
+ redis_url = task_kwargs.get("redis_url", "redis://localhost:6379")
107
+ streamer = RedisStreamer(redis_url=redis_url, channel=channel)
108
+ return await streamer.run(graph, graph_input, config, task_kwargs=task_kwargs)
109
+
110
+
111
async def _run_emit(graph, graph_input, config, on_node_event) -> Any:
    """Run the graph and deliver collected node events to *on_node_event*."""
    from langgraph_celery.emitter import NodeEventEmitter

    # Fall back to a no-op callback when none was configured.
    handler = on_node_event if on_node_event is not None else (lambda _event: None)
    emitter = NodeEventEmitter(callback=handler)
    return await emitter.run(graph, graph_input, config)
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: langgraph-celery
3
+ Version: 0.1.0
4
+ Summary: Bridge between Celery workers and LangGraph agents
5
+ Author-email: Charef <ccherrad@gmail.com>
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: celery>=5.6.3
8
+ Requires-Dist: langgraph>=1.1.10
@@ -0,0 +1,13 @@
1
+ langgraph_celery/__init__.py,sha256=bL5p51zqAw7C7mqazOunH-H4nFru_l6JYiQVJZJ-7nM,202
2
+ langgraph_celery/bridge.py,sha256=YfEYV5sIGHj-WCkGcVz7Bd4rZzqEazwkx3DPym6yv88,933
3
+ langgraph_celery/checkpointing.py,sha256=5XLrrjsa9umViY6FFT6CtCrT2NBC11A13EgvMbu0Hu0,993
4
+ langgraph_celery/emitter.py,sha256=jIzlZBvK9G9-egOFKm6895ibmMCfzyzPAM5HotwMMr8,612
5
+ langgraph_celery/events.py,sha256=8tE1eDNa_x3us8hImypNXBPoNjxrMCwUVGPvF1IY2NE,2901
6
+ langgraph_celery/interrupt.py,sha256=AkrvPl0P94_SgTXxQAz-hxO9okK41-_YJajg7-R3CP4,1447
7
+ langgraph_celery/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ langgraph_celery/task.py,sha256=XUf70IbQKwynGOFh7ExlOcHlqX8XSiX6JauLo1iXFYc,3783
9
+ langgraph_celery/streaming/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ langgraph_celery/streaming/redis.py,sha256=4PVn2bsnPiX-D4E73-wGn0fr1wlMI3zuJk9SomUeQbk,2547
11
+ langgraph_celery-0.1.0.dist-info/METADATA,sha256=jrZcQQQUaJr8tp0bgBERT8b-4u_epqzqq91A_G68qwA,248
12
+ langgraph_celery-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
13
+ langgraph_celery-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.29.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any