langgraph-celery 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_celery/__init__.py +5 -0
- langgraph_celery/bridge.py +32 -0
- langgraph_celery/checkpointing.py +30 -0
- langgraph_celery/emitter.py +18 -0
- langgraph_celery/events.py +84 -0
- langgraph_celery/interrupt.py +41 -0
- langgraph_celery/py.typed +0 -0
- langgraph_celery/streaming/__init__.py +0 -0
- langgraph_celery/streaming/redis.py +67 -0
- langgraph_celery/task.py +116 -0
- langgraph_celery-0.1.0.dist-info/METADATA +8 -0
- langgraph_celery-0.1.0.dist-info/RECORD +13 -0
- langgraph_celery-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
async def invoke(graph: Any, input: dict, config: dict | None = None) -> dict:
    """Run ``graph`` once through its ``ainvoke`` coroutine and return the output.

    A missing (``None``) or empty ``config`` is replaced by a fresh ``{}``.
    """
    effective_config = config if config else {}
    return await graph.ainvoke(input, config=effective_config)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
async def stream_events(
    graph: Any,
    input: dict,
    config: dict | None = None,
) -> AsyncIterator[tuple[str, str, dict]]:
    """Yield ``(kind, name, data)`` for every ``astream_events`` (v2) event of ``graph``.

    Missing ``"event"``/``"name"`` fields default to ``""`` and a missing
    ``"data"`` field to ``{}``. A ``None``/empty ``config`` becomes ``{}``.
    """
    event_source = graph.astream_events(input, config=config or {}, version="v2")
    async for raw in event_source:
        yield (raw.get("event", ""), raw.get("name", ""), raw.get("data", {}))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def run_sync(coro) -> Any:
    """Drive the coroutine ``coro`` to completion from synchronous code.

    If no event loop is running in the current thread, ``asyncio.run`` is used
    directly. Otherwise, because ``asyncio.run`` cannot be nested, the coroutine
    runs on a fresh event loop in a single worker thread. This call blocks until
    that loop finishes.

    Returns:
        Whatever the coroutine returns; exceptions it raises propagate.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_running = False
    else:
        loop_running = True

    if not loop_running:
        # Deliberately outside the ``except`` handler above. Otherwise every
        # error raised by the coroutine would carry the "no running event loop"
        # RuntimeError as its __context__, which produces misleading tracebacks.
        return asyncio.run(coro)

    # A loop is already running in this thread: run the coroutine on its own
    # loop in a worker thread and block until it completes.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def build_checkpointer(checkpointer: str) -> Any:
    """Return the LangGraph checkpointer selected by ``checkpointer``.

    * ``"memory"`` returns a new ``MemorySaver`` instance (a new one per call).
    * ``"postgres"`` returns the ``AsyncPostgresSaver`` class from
      ``langgraph.checkpoint.postgres.aio``.

    The backend module is imported only when its name is requested.

    NOTE(review): the postgres branch hands back the class itself, not a
    connected instance. Callers that assign the result directly as a graph
    checkpointer presumably need an instance (e.g. built from a connection
    string) — confirm.

    Raises:
        ValueError: ``checkpointer`` is neither ``"memory"`` nor ``"postgres"``.
    """
    if checkpointer not in ("memory", "postgres"):
        raise ValueError(f"Unknown checkpointer: {checkpointer!r}. Use 'memory' or 'postgres'.")

    if checkpointer == "postgres":
        from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

        return AsyncPostgresSaver

    from langgraph.checkpoint.memory import MemorySaver

    return MemorySaver()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def resolve_thread_id(thread_id_from: str, kwargs: dict) -> str:
    """Derive a checkpoint thread id from one of the task's arguments.

    The value stored under ``kwargs[thread_id_from]`` is formatted with a
    ``"task:"`` prefix. Falsy values such as ``0`` or ``""`` are accepted; only
    ``None`` counts as missing.

    Raises:
        KeyError: the argument is absent from ``kwargs`` or is ``None``.
    """
    value = kwargs.get(thread_id_from)
    if value is not None:
        return f"task:{value}"
    raise KeyError(
        f"thread_id_from={thread_id_from!r} not found in task kwargs: {list(kwargs)}"
    )
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def inject_thread_id(config: dict, thread_id: str) -> dict:
    """Return a copy of ``config`` whose ``configurable.thread_id`` is ``thread_id``.

    Both ``config`` and its ``configurable`` mapping are shallow-copied, so the
    caller's dictionaries are left unmodified. Any other ``configurable`` keys
    are kept.
    """
    configurable = dict(config.get("configurable", {}), thread_id=thread_id)
    return dict(config, configurable=configurable)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from langgraph_celery.bridge import stream_events
|
|
7
|
+
from langgraph_celery.events import GraphResult, NodeEvent
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class NodeEventEmitter:
    """Run a graph and replay its collected tool events to a callback.

    The callback receives each ``NodeEvent`` gathered by
    ``GraphResult.from_stream``, in stream order. It is called only after the
    whole run has finished, not while the graph is still executing.
    """

    def __init__(self, callback: Callable[[NodeEvent], None]) -> None:
        self._callback = callback

    async def run(self, graph: Any, input: dict, config: dict | None = None) -> GraphResult:
        """Execute ``graph`` via its event stream, notify the callback, return the summary."""
        event_stream = stream_events(graph, input, config)
        summary = await GraphResult.from_stream(event_stream)
        for node_event in summary.node_events:
            self._callback(node_event)
        return summary
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncIterator
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _content_text(content: Any) -> str:
    """Extract plain text from a message or chunk ``content`` value.

    A string is returned unchanged. A list is treated as content parts, and the
    ``"text"`` of every ``{"type": "text", ...}`` dict part is concatenated.
    Anything else (including ``None``) yields ``""``.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return ""


@dataclass
class NodeEvent:
    """One tool lifecycle event captured from a graph event stream."""

    kind: str  # raw event kind, e.g. "on_tool_start" or "on_tool_end"
    node: str  # the event's ``name`` field
    data: dict[str, Any] = field(default_factory=dict)  # the event's ``data`` payload


@dataclass
class GraphResult:
    """Summary of a single graph run."""

    output: dict[str, Any]  # final graph output ({} when none was observed)
    full_answer: str = ""  # text of the final AI answer
    tool_calls: list[dict] = field(default_factory=list)
    node_events: list[NodeEvent] = field(default_factory=list)
    interrupted: bool = False

    @classmethod
    def from_output(cls, output: dict) -> GraphResult:
        """Summarise a final graph ``output`` mapping.

        ``output["messages"]`` is walked newest-first. The first ``AIMessage``
        with content and no tool calls supplies ``full_answer``. List-style
        content is reduced to its text parts, so ``full_answer`` is always a
        ``str``. Tool calls of ``AIMessage``s newer than that answer are
        collected, newest message first.
        """
        messages = output.get("messages", [])
        full_answer = ""
        tool_calls: list[dict] = []

        # Import lazily and only when needed: an output without messages can be
        # summarised without langchain_core.
        if messages:
            from langchain_core.messages import AIMessage

            for msg in reversed(messages):
                if not isinstance(msg, AIMessage):
                    continue
                calls = getattr(msg, "tool_calls", None)
                if msg.content and not calls:
                    full_answer = _content_text(msg.content)
                    break
                if calls:
                    tool_calls.extend(calls)

        return cls(output=output, full_answer=full_answer, tool_calls=tool_calls)

    @classmethod
    async def from_stream(
        cls,
        event_iter: AsyncIterator[tuple[str, str, dict]],
    ) -> GraphResult:
        """Consume ``(kind, name, data)`` events and summarise the run.

        * ``on_chat_model_stream`` chunks contribute their text as tokens.
        * ``on_tool_start``/``on_tool_end`` are recorded as ``NodeEvent``s.
          Each ``on_tool_end`` also yields a ``{"name", "input", "output"}``
          tool call.
        * The last ``on_chain_end`` named ``"LangGraph"`` provides ``output``.
          A missing or ``None`` output becomes ``{}``.

        The answer and tool calls found in ``output`` take precedence. The
        joined tokens and the stream-derived tool calls are the fallback.
        """
        tokens: list[str] = []
        tool_calls: list[dict] = []
        node_events: list[NodeEvent] = []
        output: dict[str, Any] = {}

        async for kind, name, data in event_iter:
            if kind == "on_chat_model_stream":
                chunk = data.get("chunk")
                if chunk is not None:
                    text = _content_text(getattr(chunk, "content", None))
                    if text:
                        tokens.append(text)

            elif kind in ("on_tool_start", "on_tool_end"):
                node_events.append(NodeEvent(kind=kind, node=name, data=data))
                if kind == "on_tool_end":
                    tool_input = data.get("input") or {}
                    tool_output = data.get("output")
                    tool_calls.append({"name": name, "input": tool_input, "output": tool_output})

            elif kind == "on_chain_end" and name == "LangGraph":
                # Guard against an explicit ``None`` output, which would otherwise
                # crash from_output() with AttributeError.
                output = data.get("output") or {}

        result = cls.from_output(output)
        if tokens and not result.full_answer:
            result.full_answer = "".join(tokens)
        result.node_events = node_events
        if not result.tool_calls:
            result.tool_calls = tool_calls
        return result
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from langgraph_celery.bridge import invoke, run_sync
|
|
6
|
+
from langgraph_celery.events import GraphResult
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _interrupt_aware_result(output: dict) -> GraphResult:
    """Wrap a graph ``output`` in a ``GraphResult``, flagging pending interrupts.

    When ``output`` has a non-empty ``"__interrupt__"`` entry, the result is
    marked ``interrupted=True``. The ``.value`` of every interrupt object is
    exposed under ``output["__interrupt_values__"]``. Otherwise the output is
    summarised with ``GraphResult.from_output``.

    NOTE(review): relies on ``ainvoke`` reporting pending interrupts under the
    ``"__interrupt__"`` key — confirm for the LangGraph version in use.
    """
    interrupts = output.get("__interrupt__")
    if not interrupts:
        return GraphResult.from_output(output)
    # Annotate a copy so the dict returned by the graph is left untouched.
    annotated = dict(output)
    annotated["__interrupt_values__"] = [item.value for item in interrupts]
    return GraphResult(output=annotated, interrupted=True)


async def _run_with_interrupt(graph: Any, graph_input: Any, config: dict) -> GraphResult:
    """Invoke ``graph`` on ``graph_input`` and report whether it stopped at an interrupt."""
    output = await invoke(graph, graph_input, config)
    return _interrupt_aware_result(output)


async def _resume(graph: Any, resume_value: Any, config: dict) -> GraphResult:
    """Resume a paused run of ``graph`` by invoking it with ``Command(resume=resume_value)``.

    ``config`` presumably has to identify the paused thread (e.g. via
    ``configurable.thread_id``) — confirm against the checkpointer setup.
    """
    from langgraph.types import Command

    output = await invoke(graph, Command(resume=resume_value), config)
    return _interrupt_aware_result(output)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def run_with_interrupt(graph: Any, graph_input: Any, config: dict) -> GraphResult:
    """Synchronously run ``graph``; ``result.interrupted`` tells whether it paused."""
    pending = _run_with_interrupt(graph, graph_input, config)
    return run_sync(pending)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def resume(graph: Any, resume_value: Any, config: dict) -> GraphResult:
    """Synchronously resume a paused ``graph`` run, feeding it ``resume_value``."""
    pending = _resume(graph, resume_value, config)
    return run_sync(pending)
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from langgraph_celery.bridge import stream_events
|
|
7
|
+
from langgraph_celery.events import GraphResult
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RedisStreamer:
    """Publish a graph's streamed events to Redis pub/sub while summarising the run.

    Each message is a JSON string published on the resolved channel:

    * ``{"type": "token", "content": <text>}`` for every non-empty text piece of
      an ``on_chat_model_stream`` chunk;
    * ``{"type": "tool_start", "name": ...}`` / ``{"type": "tool_end", "name": ...}``
      for tool events;
    * ``{"type": "done"}`` for an ``on_chain_end`` event named ``"LangGraph"``.

    NOTE(review): nested subgraphs may also emit ``on_chain_end`` under the name
    ``"LangGraph"``, which would publish ``done`` more than once — confirm.
    """

    def __init__(self, redis_url: str, channel: str) -> None:
        self._redis_url = redis_url
        self._channel = channel  # may contain str.format placeholders

    def _resolve_channel(self, kwargs: dict) -> str:
        """Fill the channel template from ``kwargs``; keep it verbatim if a key is missing."""
        template = self._channel
        try:
            resolved = template.format(**kwargs)
        except KeyError:
            resolved = template
        return resolved

    async def run(
        self,
        graph: Any,
        input: dict,
        config: dict | None = None,
        *,
        task_kwargs: dict | None = None,
    ) -> GraphResult:
        """Run ``graph``, relaying its events to Redis, and return the summary.

        ``task_kwargs`` fill the channel template. A Redis client is created from
        ``redis_url`` for this run and closed afterwards, even if the run fails.
        """
        import redis.asyncio as aioredis

        target_channel = self._resolve_channel(task_kwargs or {})
        client = aioredis.from_url(self._redis_url)
        try:
            relayed = self._publish_stream(client, target_channel, graph, input, config)
            return await GraphResult.from_stream(relayed)
        finally:
            await client.aclose()

    @staticmethod
    def _text_pieces(content: Any) -> list[str]:
        """Return the non-empty text pieces of a chunk's ``content`` (str or content parts)."""
        if isinstance(content, str):
            return [content] if content else []
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text = part.get("text", "")
                    if text:
                        pieces.append(text)
            return pieces
        return []

    async def _publish_stream(
        self, client: Any, channel: str, graph: Any, input: dict, config: dict | None
    ):
        """Re-yield every ``stream_events`` item unchanged, publishing it to ``channel`` first."""
        async for kind, name, data in stream_events(graph, input, config):
            if kind == "on_chat_model_stream":
                chunk = data.get("chunk")
                if chunk is not None:
                    for piece in self._text_pieces(getattr(chunk, "content", None)):
                        await client.publish(channel, json.dumps({"type": "token", "content": piece}))
            elif kind in ("on_tool_start", "on_tool_end"):
                event_type = "tool_start" if kind == "on_tool_start" else "tool_end"
                await client.publish(channel, json.dumps({"type": event_type, "name": name}))
            elif kind == "on_chain_end" and name == "LangGraph":
                await client.publish(channel, json.dumps({"type": "done"}))
            yield kind, name, data
|
langgraph_celery/task.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import inspect
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from langgraph_celery.bridge import invoke, run_sync
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def task(
    graph_factory: Callable[..., Any] | None = None,
    *,
    streaming: str | None = None,
    channel: str | None = None,
    checkpointer: str | None = None,
    thread_id_from: str | None = None,
    emit_node_events: bool = False,
    on_node_event: Callable | None = None,
    interrupt_before: list[str] | None = None,
):
    """Create a decorator that runs a LangGraph graph in place of the decorated function.

    Typical use::

        @task(make_graph, checkpointer="memory", thread_id_from="session_id")
        def chat(messages, session_id): ...

    The decorated function's body is never executed; only its signature is
    used. On every call the wrapper:

    * builds the graph with ``graph_factory``, forwarding the call arguments
      whose names match the factory's parameters;
    * uses every bound call argument (defaults applied) as the graph input;
    * builds the run config (checkpointer and thread id);
    * runs the graph synchronously in the first applicable mode:
      ``streaming="redis"``, then ``emit_node_events``, then a non-empty
      ``interrupt_before``, otherwise a plain invocation.

    Args:
        graph_factory: Callable returning the graph to run. It is always treated
            as the graph builder, so ``task`` must be called
            (``@task(make_graph, ...)``), not used bare.
        streaming: ``"redis"`` publishes the run to Redis; other values are ignored.
        channel: Redis channel, optionally a ``str.format`` template over the
            call arguments. Required when ``streaming="redis"``.
        checkpointer: ``"memory"`` or ``"postgres"``; attached to the graph.
        thread_id_from: name of the call argument used as the thread id.
        emit_node_events: after the run, report tool events to ``on_node_event``.
        on_node_event: callback for ``emit_node_events``; events are dropped if ``None``.
        interrupt_before: when non-empty, run with interrupt detection.
            NOTE(review): only its truthiness is used; the node names are not
            passed to the graph. Confirm that static interrupts are configured
            when the graph is compiled.

    Returns:
        The decorator. A wrapped function returns the graph output in plain mode
        and a ``GraphResult`` in the other modes. Its options are exposed as
        ``wrapper._langgraph_celery``.
    """

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Explicitly supplied arguments keyed by parameter name, positional
            # ones included, so factory parameters, ``thread_id_from`` and channel
            # placeholders also resolve for positional calls. Raw keywords are
            # layered on top so every keyword stays visible under its own name
            # (including ones collected by a ``**kwargs`` parameter).
            call_args = {**inspect.signature(fn).bind(*args, **kwargs).arguments, **kwargs}

            graph = _resolve_graph(graph_factory, fn, args, call_args)
            graph_input = _build_input(fn, args, kwargs)
            config = _build_config(checkpointer, thread_id_from, call_args, graph)

            if streaming == "redis":
                return run_sync(_run_redis(graph, graph_input, config, channel, call_args))

            if emit_node_events:
                return run_sync(_run_emit(graph, graph_input, config, on_node_event))

            if interrupt_before:
                from langgraph_celery.interrupt import run_with_interrupt

                return run_with_interrupt(graph, graph_input, config)

            return run_sync(invoke(graph, graph_input, config))

        wrapper._langgraph_celery = {
            "graph_factory": graph_factory,
            "streaming": streaming,
            "channel": channel,
            "checkpointer": checkpointer,
            "thread_id_from": thread_id_from,
            "emit_node_events": emit_node_events,
            "interrupt_before": interrupt_before,
        }
        return wrapper

    # Do not special-case a callable first argument as "bare decorator" usage:
    # the graph factory is itself callable, and treating it as the decorated
    # function would discard it and make every call fail.
    return decorator
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _resolve_graph(graph_factory, fn, args, kwargs):
    """Build the graph for one task call by invoking ``graph_factory``.

    Entries of ``kwargs`` whose names match a keyword-passable parameter of the
    factory are forwarded. A factory that declares ``**kwargs`` receives every
    entry. Positional-only parameters are never passed by name, because that
    would raise TypeError. ``fn`` and ``args`` are accepted for interface
    compatibility but are not used.

    Raises:
        ValueError: ``graph_factory`` is ``None``.
    """
    if graph_factory is None:
        raise ValueError("graph_factory must be provided via task(graph_factory=...)")

    params = list(inspect.signature(graph_factory).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return graph_factory(**kwargs)

    keyword_names = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    return graph_factory(**{k: v for k, v in kwargs.items() if k in keyword_names})
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _build_input(fn, args, kwargs) -> dict:
    """Map a call's arguments onto ``fn``'s parameters, filling in defaults.

    The task wrapper uses the resulting name-to-value mapping as the graph input.

    Raises:
        TypeError: the arguments do not fit ``fn``'s signature.
    """
    binding = inspect.signature(fn).bind(*args, **kwargs)
    binding.apply_defaults()
    return {name: value for name, value in binding.arguments.items()}
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _build_config(checkpointer_name, thread_id_from, kwargs, graph) -> dict:
    """Prepare the run config and optionally attach a checkpointer to ``graph``.

    Args:
        checkpointer_name: ``"memory"``, ``"postgres"`` or ``None`` for none.
        thread_id_from: name of the call argument holding the thread id, or a
            falsy value for none.
        kwargs: the task's call arguments, used to resolve the thread id.
        graph: the graph object. Its ``checkpointer`` attribute is overwritten
            when ``checkpointer_name`` is given; if that fails, a RuntimeWarning
            is emitted and the run continues without it.

    Returns:
        A config dict that contains ``configurable.thread_id`` whenever
        ``thread_id_from`` is set.

    Raises:
        ValueError: unknown ``checkpointer_name``.
        KeyError: the ``thread_id_from`` argument is missing or ``None``.

    NOTE(review): a new checkpointer is built on every call, so with ``"memory"``
    the state presumably does not survive across separate task invocations —
    confirm this is intended.
    """
    config: dict = {}

    if checkpointer_name is not None:
        from langgraph_celery.checkpointing import build_checkpointer

        cp = build_checkpointer(checkpointer_name)
        try:
            graph.checkpointer = cp
        except Exception as exc:
            # Best effort, as before, but no longer silent: without the
            # checkpointer the thread id below has no persistent effect.
            import warnings

            warnings.warn(
                f"could not attach {checkpointer_name!r} checkpointer to graph: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    # Inject the thread id whenever it is configured, not only when a
    # checkpointer name is given. A graph compiled with its own checkpointer
    # presumably needs configurable.thread_id as well — verify with LangGraph.
    if thread_id_from:
        from langgraph_celery.checkpointing import inject_thread_id, resolve_thread_id

        config = inject_thread_id(config, resolve_thread_id(thread_id_from, kwargs))

    return config
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
async def _run_redis(graph, graph_input, config, channel, task_kwargs) -> Any:
    """Run ``graph`` while publishing its stream to a Redis pub/sub channel.

    The Redis URL comes from ``task_kwargs["redis_url"]`` and defaults to
    ``redis://localhost:6379``. ``task_kwargs`` are also handed to the streamer
    to fill the ``channel`` template.

    Raises:
        ValueError: ``channel`` is ``None``.
    """
    from langgraph_celery.streaming.redis import RedisStreamer

    if channel is None:
        raise ValueError("channel must be set when streaming='redis'")

    streamer = RedisStreamer(
        redis_url=task_kwargs.get("redis_url", "redis://localhost:6379"),
        channel=channel,
    )
    return await streamer.run(graph, graph_input, config, task_kwargs=task_kwargs)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
async def _run_emit(graph, graph_input, config, on_node_event) -> Any:
    """Run ``graph`` and hand its collected tool events to ``on_node_event``.

    Without a callback, the events are still recorded on the returned result but
    are not reported anywhere else.
    """
    from langgraph_celery.emitter import NodeEventEmitter

    def _discard(event):
        return None

    emitter = NodeEventEmitter(callback=_discard if on_node_event is None else on_node_event)
    return await emitter.run(graph, graph_input, config)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
langgraph_celery/__init__.py,sha256=bL5p51zqAw7C7mqazOunH-H4nFru_l6JYiQVJZJ-7nM,202
|
|
2
|
+
langgraph_celery/bridge.py,sha256=YfEYV5sIGHj-WCkGcVz7Bd4rZzqEazwkx3DPym6yv88,933
|
|
3
|
+
langgraph_celery/checkpointing.py,sha256=5XLrrjsa9umViY6FFT6CtCrT2NBC11A13EgvMbu0Hu0,993
|
|
4
|
+
langgraph_celery/emitter.py,sha256=jIzlZBvK9G9-egOFKm6895ibmMCfzyzPAM5HotwMMr8,612
|
|
5
|
+
langgraph_celery/events.py,sha256=8tE1eDNa_x3us8hImypNXBPoNjxrMCwUVGPvF1IY2NE,2901
|
|
6
|
+
langgraph_celery/interrupt.py,sha256=AkrvPl0P94_SgTXxQAz-hxO9okK41-_YJajg7-R3CP4,1447
|
|
7
|
+
langgraph_celery/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
+
langgraph_celery/task.py,sha256=XUf70IbQKwynGOFh7ExlOcHlqX8XSiX6JauLo1iXFYc,3783
|
|
9
|
+
langgraph_celery/streaming/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
10
|
+
langgraph_celery/streaming/redis.py,sha256=4PVn2bsnPiX-D4E73-wGn0fr1wlMI3zuJk9SomUeQbk,2547
|
|
11
|
+
langgraph_celery-0.1.0.dist-info/METADATA,sha256=jrZcQQQUaJr8tp0bgBERT8b-4u_epqzqq91A_G68qwA,248
|
|
12
|
+
langgraph_celery-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
|
|
13
|
+
langgraph_celery-0.1.0.dist-info/RECORD,,
|