langgraph-celery 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,22 @@
1
# Continuous integration: lint and test on every push to main and on pull requests.
name: CI

on:
  push:
    branches: ["main"]
  pull_request:

jobs:
  check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      # Install uv with its cache enabled to speed up dependency resolution.
      - uses: astral-sh/setup-uv@v5
        with:
          enable-cache: true

      # Install dependencies exactly as pinned in the lockfile.
      - run: uv sync --frozen

      # Lint only the library sources.
      - run: uv run ruff check src/

      - run: uv run pytest
@@ -0,0 +1,31 @@
1
# Release workflow: build and publish to PyPI whenever a version tag (v*) is pushed.
name: Release

on:
  push:
    tags:
      - "v*"

jobs:
  release:
    runs-on: ubuntu-latest
    # The "release" environment can gate publishing behind required reviewers.
    environment: release
    permissions:
      # id-token is required for PyPI trusted publishing (OIDC) — no API token stored.
      id-token: write
      contents: read

    steps:
      - uses: actions/checkout@v4

      - uses: astral-sh/setup-uv@v5
        with:
          enable-cache: true

      # Re-run lint and the test suite so a broken tag never ships.
      - run: uv sync --frozen

      - run: uv run ruff check src/

      - run: uv run pytest

      # Build sdist + wheel into dist/.
      - run: uv build

      # Publishes everything in dist/ via trusted publishing.
      - uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,6 @@
1
# Python bytecode caches
__pycache__/
*.py[cod]
# Packaging build artifacts
*.egg-info/
dist/
# Local virtualenv and environment secrets
.venv/
.env
@@ -0,0 +1 @@
1
+ 3.13
@@ -0,0 +1,8 @@
1
+ Metadata-Version: 2.4
2
+ Name: langgraph-celery
3
+ Version: 0.1.0
4
+ Summary: Bridge between Celery workers and LangGraph agents
5
+ Author-email: Charef <ccherrad@gmail.com>
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: celery>=5.6.3
8
+ Requires-Dist: langgraph>=1.1.10
File without changes
@@ -0,0 +1,35 @@
1
# Package metadata and tooling configuration for langgraph-celery.
[project]
name = "langgraph-celery"
version = "0.1.0"
description = "Bridge between Celery workers and LangGraph agents"
readme = "README.md"
authors = [
    { name = "Charef", email = "ccherrad@gmail.com" }
]
requires-python = ">=3.11"
dependencies = [
    "celery>=5.6.3",
    "langgraph>=1.1.10",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# Development-only dependencies (PEP 735 dependency groups, installed by `uv sync`).
[dependency-groups]
dev = [
    "pytest>=9.0.3",
    "pytest-asyncio>=1.3.0",
    "ruff>=0.15.12",
]

[tool.pytest.ini_options]
# Run async test functions on an event loop without per-test asyncio markers.
asyncio_mode = "auto"
testpaths = ["tests"]

[tool.ruff]
line-length = 100
target-version = "py311"

[tool.ruff.lint]
# E/F: pycodestyle + pyflakes, I: import sorting, UP: pyupgrade modernization.
select = ["E", "F", "I", "UP"]
@@ -0,0 +1,5 @@
1
+ from langgraph_celery.events import GraphResult, NodeEvent
2
+ from langgraph_celery.interrupt import resume
3
+ from langgraph_celery.task import task
4
+
5
+ __all__ = ["task", "resume", "GraphResult", "NodeEvent"]
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ from collections.abc import AsyncIterator
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from typing import Any
7
+
8
+
9
async def invoke(graph: Any, input: dict, config: dict | None = None) -> dict:
    """Run *graph* to completion for *input* and return its final state dict."""
    effective_config = config if config else {}
    return await graph.ainvoke(input, config=effective_config)
11
+
12
+
13
async def stream_events(
    graph: Any,
    input: dict,
    config: dict | None = None,
) -> AsyncIterator[tuple[str, str, dict]]:
    """Adapt the graph's v2 event stream into plain (kind, name, data) triples.

    Missing fields default to "" / "" / {} so consumers never see None.
    """
    effective_config = config if config else {}
    async for raw_event in graph.astream_events(input, config=effective_config, version="v2"):
        yield (
            raw_event.get("event", ""),
            raw_event.get("name", ""),
            raw_event.get("data", {}),
        )
23
+
24
+
25
def run_sync(coro) -> Any:
    """Drive *coro* to completion from synchronous code and return its result.

    If no event loop is running in this thread, asyncio.run() is used directly.
    Inside a running loop asyncio.run() would raise, so the coroutine is run on
    a fresh loop in a short-lived worker thread instead.
    """
    loop_is_running = True
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    if not loop_is_running:
        return asyncio.run(coro)
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
def build_checkpointer(checkpointer: str) -> Any:
    """Resolve a checkpointer backend by name ("memory" or "postgres").

    NOTE(review): the two branches are inconsistent — "memory" returns an
    *instance* (MemorySaver()) while "postgres" returns the *class*
    AsyncPostgresSaver, presumably because instantiating it requires a
    connection string. Callers must handle both shapes; confirm this
    asymmetry is intentional.

    Raises:
        ValueError: for any name other than "memory" or "postgres".
    """
    if checkpointer == "memory":
        from langgraph.checkpoint.memory import MemorySaver
        return MemorySaver()
    if checkpointer == "postgres":
        from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
        # Returned un-instantiated: AsyncPostgresSaver needs a DB connection.
        return AsyncPostgresSaver
    raise ValueError(f"Unknown checkpointer: {checkpointer!r}. Use 'memory' or 'postgres'.")
14
+
15
+
16
def resolve_thread_id(thread_id_from: str, kwargs: dict) -> str:
    """Derive the checkpoint thread id ("task:<value>") from a named task kwarg.

    Raises:
        KeyError: when the kwarg is absent (or its value is None).
    """
    raw = kwargs.get(thread_id_from)
    if raw is None:
        available = list(kwargs)
        raise KeyError(
            f"thread_id_from={thread_id_from!r} not found in task kwargs: {available}"
        )
    return f"task:{raw}"
23
+
24
+
25
def inject_thread_id(config: dict, thread_id: str) -> dict:
    """Return a copy of *config* with configurable.thread_id set.

    Both the top-level dict and the nested "configurable" dict are shallow
    copies, so the caller's config is never mutated.
    """
    existing = config.get("configurable", {})
    return {**config, "configurable": {**existing, "thread_id": thread_id}}
@@ -0,0 +1,18 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
+ from langgraph_celery.bridge import stream_events
7
+ from langgraph_celery.events import GraphResult, NodeEvent
8
+
9
+
10
class NodeEventEmitter:
    """Runs a graph and replays each captured node event into a callback."""

    def __init__(self, callback: Callable[[NodeEvent], None]) -> None:
        # Invoked once per node event after the run completes.
        self._callback = callback

    async def run(self, graph: Any, input: dict, config: dict | None = None) -> GraphResult:
        """Execute the graph, then emit every collected node event in order."""
        outcome = await GraphResult.from_stream(stream_events(graph, input, config))
        for node_event in outcome.node_events:
            self._callback(node_event)
        return outcome
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import AsyncIterator
4
+ from dataclasses import dataclass, field
5
+ from typing import Any
6
+
7
+
8
@dataclass
class NodeEvent:
    """A single node-level event captured from a graph's event stream."""

    kind: str  # LangChain event kind, e.g. "on_tool_start" / "on_tool_end"
    node: str  # name of the node/tool that produced the event
    data: dict[str, Any] = field(default_factory=dict)  # raw event payload
13
+
14
+
15
@dataclass
class GraphResult:
    """Normalized result of one LangGraph run.

    Attributes:
        output: raw final state returned by the graph.
        full_answer: text of the newest AI message without tool calls ("" if none).
        tool_calls: tool invocations observed during the run.
        node_events: per-node stream events (tool start/end) in arrival order.
        interrupted: True when the run paused on a graph interrupt.
    """

    output: dict[str, Any]
    full_answer: str = ""
    tool_calls: list[dict] = field(default_factory=list)
    node_events: list[NodeEvent] = field(default_factory=list)
    interrupted: bool = False

    @classmethod
    def from_output(cls, output: dict) -> GraphResult:
        """Build a result from a finished graph state.

        Scans messages newest-first: the most recent AIMessage that has
        content and no tool calls becomes ``full_answer``; tool calls seen on
        newer messages are collected along the way.
        """
        from langchain_core.messages import AIMessage

        messages = output.get("messages", [])
        full_answer = ""
        tool_calls = []

        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                # NOTE(review): msg.content may be a list of content blocks for
                # some providers; it is used as-is here — confirm callers expect that.
                if msg.content and not getattr(msg, "tool_calls", None):
                    full_answer = msg.content
                    break
                if getattr(msg, "tool_calls", None):
                    tool_calls.extend(msg.tool_calls)

        return cls(output=output, full_answer=full_answer, tool_calls=tool_calls)

    @classmethod
    async def from_stream(
        cls,
        event_iter: AsyncIterator[tuple[str, str, dict]],
    ) -> GraphResult:
        """Consume a (kind, name, data) event stream and build a result.

        Token chunks are accumulated as a fallback answer; the terminal
        ``on_chain_end`` event of the top-level "LangGraph" run supplies the
        authoritative output when present.
        """
        tokens: list[str] = []
        tool_calls: list[dict] = []
        node_events: list[NodeEvent] = []
        output: dict[str, Any] = {}

        async for kind, name, data in event_iter:
            if kind == "on_chat_model_stream":
                chunk = data.get("chunk")
                if chunk is not None:
                    content = getattr(chunk, "content", None)
                    if isinstance(content, str) and content:
                        tokens.append(content)
                    elif isinstance(content, list):
                        # Structured content: keep only the text parts.
                        for part in content:
                            if isinstance(part, dict) and part.get("type") == "text":
                                tokens.append(part.get("text", ""))

            elif kind == "on_tool_start":
                node_events.append(NodeEvent(kind=kind, node=name, data=data))

            elif kind == "on_tool_end":
                node_events.append(NodeEvent(kind=kind, node=name, data=data))
                tool_input = data.get("input") or {}
                tool_output = data.get("output")
                tool_calls.append({"name": name, "input": tool_input, "output": tool_output})

            elif kind == "on_chain_end" and name == "LangGraph":
                output = data.get("output", {})

        # (A dead branch that reassigned an already-empty ``output`` to {} was removed.)
        result = cls.from_output(output)
        if tokens and not result.full_answer:
            # No final message in the output: fall back to the streamed tokens.
            result.full_answer = "".join(tokens)
        result.node_events = node_events
        if not result.tool_calls:
            result.tool_calls = tool_calls
        return result
@@ -0,0 +1,41 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from langgraph_celery.bridge import invoke, run_sync
6
+ from langgraph_celery.events import GraphResult
7
+
8
+
9
def _interrupt_result(output: dict) -> GraphResult:
    # Shared post-processing: if the graph paused on an interrupt, package the
    # interrupt payloads; otherwise parse the finished output normally.
    # (Previously duplicated verbatim in _run_with_interrupt and _resume.)
    interrupts = output.get("__interrupt__")
    if interrupts:
        result = GraphResult(output=output, interrupted=True)
        result.output["__interrupt_values__"] = [i.value for i in interrupts]
        return result
    return GraphResult.from_output(output)


async def _run_with_interrupt(graph: Any, graph_input: Any, config: dict) -> GraphResult:
    """Invoke the graph, returning an interrupted result if it paused."""
    output = await invoke(graph, graph_input, config)
    return _interrupt_result(output)


async def _resume(graph: Any, resume_value: Any, config: dict) -> GraphResult:
    """Resume an interrupted graph by feeding *resume_value* via a Command."""
    from langgraph.types import Command

    output = await invoke(graph, Command(resume=resume_value), config)
    return _interrupt_result(output)
34
+
35
+
36
def run_with_interrupt(graph: Any, graph_input: Any, config: dict) -> GraphResult:
    """Synchronous entry point for an interruptible graph run."""
    return run_sync(_run_with_interrupt(graph, graph_input, config))
38
+
39
+
40
def resume(graph: Any, resume_value: Any, config: dict) -> GraphResult:
    """Synchronously resume a previously interrupted graph run."""
    return run_sync(_resume(graph, resume_value, config))
File without changes
@@ -0,0 +1,67 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from typing import Any
5
+
6
+ from langgraph_celery.bridge import stream_events
7
+ from langgraph_celery.events import GraphResult
8
+
9
+
10
class RedisStreamer:
    """Runs a graph while publishing its stream events to a Redis pub/sub channel.

    Published messages are JSON objects of the shape
    {"type": "token"|"tool_start"|"tool_end"|"done", ...}. The channel may be
    a template such as "chat:{task_id}" resolved from the task's kwargs.
    """

    def __init__(self, redis_url: str, channel: str) -> None:
        self._redis_url = redis_url
        # May contain {named} placeholders filled from task kwargs.
        self._channel = channel

    def _resolve_channel(self, kwargs: dict) -> str:
        """Fill the channel template from task kwargs; fall back to the raw template."""
        try:
            return self._channel.format(**kwargs)
        except (KeyError, IndexError):
            # KeyError: a named placeholder is missing from kwargs.
            # IndexError: the template uses positional "{}" fields, which can
            # never be satisfied from keyword arguments. (Previously only
            # KeyError was caught, so a "{}" template crashed the task.)
            return self._channel

    async def run(
        self,
        graph: Any,
        input: dict,
        config: dict | None = None,
        *,
        task_kwargs: dict | None = None,
    ) -> GraphResult:
        """Execute the graph, publishing events as they stream; return the result."""
        import redis.asyncio as aioredis

        channel = self._resolve_channel(task_kwargs or {})
        client = aioredis.from_url(self._redis_url)

        try:
            result = await GraphResult.from_stream(
                self._publish_stream(client, channel, graph, input, config)
            )
        finally:
            # Always release the connection, even if the graph run fails.
            await client.aclose()

        return result

    async def _publish_stream(
        self, client: Any, channel: str, graph: Any, input: dict, config: dict | None
    ):
        """Yield events unchanged while mirroring them onto the Redis channel."""
        async for kind, name, data in stream_events(graph, input, config):
            if kind == "on_chat_model_stream":
                chunk = data.get("chunk")
                if chunk is not None:
                    content = getattr(chunk, "content", None)
                    if isinstance(content, str) and content:
                        msg = json.dumps({"type": "token", "content": content})
                        await client.publish(channel, msg)
                    elif isinstance(content, list):
                        # Structured content: publish only non-empty text parts.
                        for part in content:
                            if isinstance(part, dict) and part.get("type") == "text":
                                text = part.get("text", "")
                                if text:
                                    msg = json.dumps({"type": "token", "content": text})
                                    await client.publish(channel, msg)
            elif kind == "on_tool_start":
                await client.publish(channel, json.dumps({"type": "tool_start", "name": name}))
            elif kind == "on_tool_end":
                await client.publish(channel, json.dumps({"type": "tool_end", "name": name}))
            elif kind == "on_chain_end" and name == "LangGraph":
                await client.publish(channel, json.dumps({"type": "done"}))
            yield kind, name, data
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import inspect
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+
8
+ from langgraph_celery.bridge import invoke, run_sync
9
+
10
+
11
def task(
    graph_factory: Callable[..., Any] | None = None,
    *,
    streaming: str | None = None,
    channel: str | None = None,
    checkpointer: str | None = None,
    thread_id_from: str | None = None,
    emit_node_events: bool = False,
    on_node_event: Callable | None = None,
    interrupt_before: list[str] | None = None,
):
    """Decorator turning a plain function into a LangGraph-running task body.

    The decorated function's parameters define the graph input dict; its body
    is never executed. Supports both ``@task`` and ``@task(...)`` forms.

    Args:
        graph_factory: callable that builds the graph (receives matching task kwargs).
        streaming: "redis" to publish stream events to a Redis channel.
        channel: Redis channel template (required when streaming="redis").
        checkpointer: checkpointer backend name ("memory" / "postgres").
        thread_id_from: task kwarg name used to derive the checkpoint thread id.
        emit_node_events: collect node events and feed them to *on_node_event*.
        on_node_event: callback invoked per node event when emit_node_events is set.
        interrupt_before: node names; presence enables interrupt-aware running.
    """

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            graph = _resolve_graph(graph_factory, fn, args, kwargs)
            graph_input = _build_input(fn, args, kwargs)
            config = _build_config(checkpointer, thread_id_from, kwargs, graph)

            if streaming == "redis":
                return run_sync(_run_redis(graph, graph_input, config, channel, kwargs))

            if emit_node_events:
                return run_sync(_run_emit(graph, graph_input, config, on_node_event))

            if interrupt_before:
                from langgraph_celery.interrupt import run_with_interrupt

                return run_with_interrupt(graph, graph_input, config)

            return run_sync(invoke(graph, graph_input, config))

        # Expose the decorator configuration for introspection and testing.
        wrapper._langgraph_celery = {
            "graph_factory": graph_factory,
            "streaming": streaming,
            "channel": channel,
            "checkpointer": checkpointer,
            "thread_id_from": thread_id_from,
            "emit_node_events": emit_node_events,
            "on_node_event": on_node_event,  # was the only option missing here
            "interrupt_before": interrupt_before,
        }
        return wrapper

    # Bare ``@task`` usage: the first positional argument is the function itself.
    if callable(graph_factory):
        fn = graph_factory
        graph_factory = None
        return decorator(fn)

    return decorator
59
+
60
+
61
+ def _resolve_graph(graph_factory, fn, args, kwargs):
62
+ if graph_factory is not None:
63
+ sig = inspect.signature(graph_factory)
64
+ params = list(sig.parameters.keys())
65
+ bound_kwargs = {k: v for k, v in kwargs.items() if k in params}
66
+ return graph_factory(**bound_kwargs)
67
+ raise ValueError("graph_factory must be provided via task(graph=...)")
68
+
69
+
70
+ def _build_input(fn, args, kwargs) -> dict:
71
+ sig = inspect.signature(fn)
72
+ bound = sig.bind(*args, **kwargs)
73
+ bound.apply_defaults()
74
+ return dict(bound.arguments)
75
+
76
+
77
+ def _build_config(checkpointer_name, thread_id_from, kwargs, graph) -> dict:
78
+ config: dict = {}
79
+
80
+ if checkpointer_name is not None:
81
+ from langgraph_celery.checkpointing import (
82
+ build_checkpointer,
83
+ inject_thread_id,
84
+ resolve_thread_id,
85
+ )
86
+
87
+ cp = build_checkpointer(checkpointer_name)
88
+ try:
89
+ graph.checkpointer = cp
90
+ except Exception:
91
+ pass
92
+
93
+ if thread_id_from:
94
+ thread_id = resolve_thread_id(thread_id_from, kwargs)
95
+ config = inject_thread_id(config, thread_id)
96
+
97
+ return config
98
+
99
+
100
+ async def _run_redis(graph, graph_input, config, channel, task_kwargs) -> Any:
101
+ from langgraph_celery.streaming.redis import RedisStreamer
102
+
103
+ if channel is None:
104
+ raise ValueError("channel must be set when streaming='redis'")
105
+
106
+ redis_url = task_kwargs.get("redis_url", "redis://localhost:6379")
107
+ streamer = RedisStreamer(redis_url=redis_url, channel=channel)
108
+ return await streamer.run(graph, graph_input, config, task_kwargs=task_kwargs)
109
+
110
+
111
async def _run_emit(graph, graph_input, config, on_node_event) -> Any:
    """Run the graph through NodeEventEmitter, firing the callback per node event."""
    from langgraph_celery.emitter import NodeEventEmitter

    handler = (lambda event: None) if on_node_event is None else on_node_event
    emitter = NodeEventEmitter(callback=handler)
    return await emitter.run(graph, graph_input, config)
File without changes
@@ -0,0 +1,22 @@
1
+ import asyncio
2
+
3
+ import pytest
4
+
5
+ from langgraph_celery.bridge import run_sync
6
+
7
+
8
+ async def _add(x, y):
9
+ return x + y
10
+
11
+
12
def test_run_sync_no_loop():
    # Without a running loop, run_sync should fall back to asyncio.run.
    assert run_sync(_add(1, 2)) == 3
15
+
16
+
17
def test_run_sync_inside_loop():
    # run_sync must also work when invoked from inside an active event loop.
    async def inner():
        return run_sync(_add(3, 4))

    assert asyncio.run(inner()) == 7
@@ -0,0 +1,22 @@
1
+ from langchain_core.messages import AIMessage
2
+
3
+ from langgraph_celery.events import GraphResult
4
+
5
+
6
def test_graph_result_extracts_final_answer():
    # A single plain AI message becomes the full answer; nothing interrupted.
    result = GraphResult.from_output({"messages": [AIMessage(content="Hello world")]})
    assert result.full_answer == "Hello world"
    assert result.interrupted is False
11
+
12
+
13
def test_graph_result_skips_tool_call_messages():
    # Messages that only carry tool calls are skipped when picking the answer.
    history = [
        AIMessage(content="", tool_calls=[{"name": "search", "args": {}, "id": "1"}]),
        AIMessage(content="Final answer"),
    ]
    assert GraphResult.from_output({"messages": history}).full_answer == "Final answer"
18
+
19
+
20
def test_graph_result_empty_messages():
    # No messages at all yields an empty answer.
    assert GraphResult.from_output({"messages": []}).full_answer == ""
@@ -0,0 +1,50 @@
1
+ import pytest
2
+ from unittest.mock import MagicMock
3
+
4
+ from langgraph_celery.events import GraphResult, NodeEvent
5
+
6
+
7
+ async def _make_iter(events):
8
+ for e in events:
9
+ yield e
10
+
11
+
12
+ def _chunk(text):
13
+ m = MagicMock()
14
+ m.content = text
15
+ return m
16
+
17
+
18
async def test_from_stream_tokens():
    # Streamed token chunks are concatenated into the fallback answer.
    stream = _make_iter([
        ("on_chat_model_stream", "model", {"chunk": _chunk("Hello ")}),
        ("on_chat_model_stream", "model", {"chunk": _chunk("world")}),
        ("on_chain_end", "LangGraph", {"output": {}}),
    ])
    result = await GraphResult.from_stream(stream)
    assert result.full_answer == "Hello world"
26
+
27
+
28
async def test_from_stream_tool_events():
    # Tool start/end events are recorded as node events and as tool calls.
    stream = _make_iter([
        ("on_tool_start", "search", {"input": {"query": "foo"}}),
        ("on_tool_end", "search", {"output": "bar", "input": {"query": "foo"}}),
        ("on_chain_end", "LangGraph", {"output": {}}),
    ])
    result = await GraphResult.from_stream(stream)
    kinds = [event.kind for event in result.node_events]
    assert kinds == ["on_tool_start", "on_tool_end"]
    assert result.tool_calls[0]["name"] == "search"
39
+
40
+
41
async def test_from_stream_prefers_chain_end_output():
    from langchain_core.messages import AIMessage

    # When the final output contains a real message, it wins over streamed tokens.
    final = AIMessage(content="From output")
    stream = _make_iter([
        ("on_chat_model_stream", "model", {"chunk": _chunk("From stream")}),
        ("on_chain_end", "LangGraph", {"output": {"messages": [final]}}),
    ])
    result = await GraphResult.from_stream(stream)
    assert result.full_answer == "From output"