duragraph-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- duragraph/__init__.py +35 -0
- duragraph/cli/__init__.py +5 -0
- duragraph/cli/main.py +163 -0
- duragraph/edges.py +116 -0
- duragraph/graph.py +429 -0
- duragraph/nodes.py +252 -0
- duragraph/prompts/__init__.py +6 -0
- duragraph/prompts/decorators.py +43 -0
- duragraph/prompts/store.py +171 -0
- duragraph/py.typed +0 -0
- duragraph/types.py +100 -0
- duragraph/worker/__init__.py +5 -0
- duragraph/worker/worker.py +327 -0
- duragraph_python-0.1.0.dist-info/METADATA +224 -0
- duragraph_python-0.1.0.dist-info/RECORD +18 -0
- duragraph_python-0.1.0.dist-info/WHEEL +4 -0
- duragraph_python-0.1.0.dist-info/entry_points.txt +2 -0
- duragraph_python-0.1.0.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Prompt decorators for DuraGraph."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from functools import wraps
|
|
5
|
+
from typing import Any, TypeVar
|
|
6
|
+
|
|
7
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def prompt(
|
|
11
|
+
prompt_id: str,
|
|
12
|
+
*,
|
|
13
|
+
version: str | None = None,
|
|
14
|
+
variant: str | None = None,
|
|
15
|
+
) -> Callable[[F], F]:
|
|
16
|
+
"""Decorator to attach a prompt from the prompt store to a node.
|
|
17
|
+
|
|
18
|
+
Args:
|
|
19
|
+
prompt_id: Identifier for the prompt (e.g., "support/classify_intent").
|
|
20
|
+
version: Optional specific version (e.g., "2.1.0"). Defaults to latest.
|
|
21
|
+
variant: Optional A/B test variant.
|
|
22
|
+
|
|
23
|
+
Example:
|
|
24
|
+
@llm_node(model="gpt-4o-mini")
|
|
25
|
+
@prompt("support/classify_intent", version="2.1.0")
|
|
26
|
+
def classify(self, state):
|
|
27
|
+
return state
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
def decorator(func: F) -> F:
|
|
31
|
+
@wraps(func)
|
|
32
|
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
33
|
+
return func(*args, **kwargs)
|
|
34
|
+
|
|
35
|
+
# Attach prompt metadata
|
|
36
|
+
wrapper._prompt_metadata = { # type: ignore
|
|
37
|
+
"prompt_id": prompt_id,
|
|
38
|
+
"version": version,
|
|
39
|
+
"variant": variant,
|
|
40
|
+
}
|
|
41
|
+
return wrapper # type: ignore
|
|
42
|
+
|
|
43
|
+
return decorator
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
"""Prompt store client for DuraGraph."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class PromptStore:
|
|
9
|
+
"""Client for interacting with the DuraGraph Prompt Store."""
|
|
10
|
+
|
|
11
|
+
def __init__(
|
|
12
|
+
self,
|
|
13
|
+
base_url: str,
|
|
14
|
+
*,
|
|
15
|
+
api_key: str | None = None,
|
|
16
|
+
):
|
|
17
|
+
"""Initialize prompt store client.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
base_url: URL of the prompt store API.
|
|
21
|
+
api_key: Optional API key for authentication.
|
|
22
|
+
"""
|
|
23
|
+
self.base_url = base_url.rstrip("/")
|
|
24
|
+
self.api_key = api_key
|
|
25
|
+
self._client = httpx.Client(timeout=30.0)
|
|
26
|
+
|
|
27
|
+
def _headers(self) -> dict[str, str]:
|
|
28
|
+
"""Get request headers."""
|
|
29
|
+
headers = {"Content-Type": "application/json"}
|
|
30
|
+
if self.api_key:
|
|
31
|
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
32
|
+
return headers
|
|
33
|
+
|
|
34
|
+
def get_prompt(
|
|
35
|
+
self,
|
|
36
|
+
prompt_id: str,
|
|
37
|
+
*,
|
|
38
|
+
version: str | None = None,
|
|
39
|
+
variant: str | None = None,
|
|
40
|
+
) -> dict[str, Any]:
|
|
41
|
+
"""Get a prompt from the store.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
prompt_id: Prompt identifier.
|
|
45
|
+
version: Optional version (default: latest).
|
|
46
|
+
variant: Optional A/B variant.
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
Prompt data including content and metadata.
|
|
50
|
+
"""
|
|
51
|
+
params: dict[str, str] = {}
|
|
52
|
+
if version:
|
|
53
|
+
params["version"] = version
|
|
54
|
+
if variant:
|
|
55
|
+
params["variant"] = variant
|
|
56
|
+
|
|
57
|
+
response = self._client.get(
|
|
58
|
+
f"{self.base_url}/api/v1/prompts/{prompt_id}",
|
|
59
|
+
headers=self._headers(),
|
|
60
|
+
params=params,
|
|
61
|
+
)
|
|
62
|
+
response.raise_for_status()
|
|
63
|
+
return response.json()
|
|
64
|
+
|
|
65
|
+
def list_prompts(
|
|
66
|
+
self,
|
|
67
|
+
*,
|
|
68
|
+
namespace: str | None = None,
|
|
69
|
+
tag: str | None = None,
|
|
70
|
+
) -> list[dict[str, Any]]:
|
|
71
|
+
"""List prompts in the store.
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
namespace: Optional namespace filter.
|
|
75
|
+
tag: Optional tag filter.
|
|
76
|
+
|
|
77
|
+
Returns:
|
|
78
|
+
List of prompt metadata.
|
|
79
|
+
"""
|
|
80
|
+
params: dict[str, str] = {}
|
|
81
|
+
if namespace:
|
|
82
|
+
params["namespace"] = namespace
|
|
83
|
+
if tag:
|
|
84
|
+
params["tag"] = tag
|
|
85
|
+
|
|
86
|
+
response = self._client.get(
|
|
87
|
+
f"{self.base_url}/api/v1/prompts",
|
|
88
|
+
headers=self._headers(),
|
|
89
|
+
params=params,
|
|
90
|
+
)
|
|
91
|
+
response.raise_for_status()
|
|
92
|
+
return response.json()["prompts"]
|
|
93
|
+
|
|
94
|
+
def create_prompt(
|
|
95
|
+
self,
|
|
96
|
+
prompt_id: str,
|
|
97
|
+
content: str,
|
|
98
|
+
*,
|
|
99
|
+
description: str | None = None,
|
|
100
|
+
tags: list[str] | None = None,
|
|
101
|
+
metadata: dict[str, Any] | None = None,
|
|
102
|
+
) -> dict[str, Any]:
|
|
103
|
+
"""Create a new prompt.
|
|
104
|
+
|
|
105
|
+
Args:
|
|
106
|
+
prompt_id: Prompt identifier.
|
|
107
|
+
content: Prompt content template.
|
|
108
|
+
description: Optional description.
|
|
109
|
+
tags: Optional tags for categorization.
|
|
110
|
+
metadata: Optional additional metadata.
|
|
111
|
+
|
|
112
|
+
Returns:
|
|
113
|
+
Created prompt data.
|
|
114
|
+
"""
|
|
115
|
+
payload = {
|
|
116
|
+
"prompt_id": prompt_id,
|
|
117
|
+
"content": content,
|
|
118
|
+
}
|
|
119
|
+
if description:
|
|
120
|
+
payload["description"] = description
|
|
121
|
+
if tags:
|
|
122
|
+
payload["tags"] = tags
|
|
123
|
+
if metadata:
|
|
124
|
+
payload["metadata"] = metadata
|
|
125
|
+
|
|
126
|
+
response = self._client.post(
|
|
127
|
+
f"{self.base_url}/api/v1/prompts",
|
|
128
|
+
headers=self._headers(),
|
|
129
|
+
json=payload,
|
|
130
|
+
)
|
|
131
|
+
response.raise_for_status()
|
|
132
|
+
return response.json()
|
|
133
|
+
|
|
134
|
+
def create_version(
|
|
135
|
+
self,
|
|
136
|
+
prompt_id: str,
|
|
137
|
+
content: str,
|
|
138
|
+
*,
|
|
139
|
+
change_log: str | None = None,
|
|
140
|
+
) -> dict[str, Any]:
|
|
141
|
+
"""Create a new version of an existing prompt.
|
|
142
|
+
|
|
143
|
+
Args:
|
|
144
|
+
prompt_id: Prompt identifier.
|
|
145
|
+
content: New prompt content.
|
|
146
|
+
change_log: Optional change description.
|
|
147
|
+
|
|
148
|
+
Returns:
|
|
149
|
+
New version data.
|
|
150
|
+
"""
|
|
151
|
+
payload = {"content": content}
|
|
152
|
+
if change_log:
|
|
153
|
+
payload["change_log"] = change_log
|
|
154
|
+
|
|
155
|
+
response = self._client.post(
|
|
156
|
+
f"{self.base_url}/api/v1/prompts/{prompt_id}/versions",
|
|
157
|
+
headers=self._headers(),
|
|
158
|
+
json=payload,
|
|
159
|
+
)
|
|
160
|
+
response.raise_for_status()
|
|
161
|
+
return response.json()
|
|
162
|
+
|
|
163
|
+
def close(self) -> None:
|
|
164
|
+
"""Close the HTTP client."""
|
|
165
|
+
self._client.close()
|
|
166
|
+
|
|
167
|
+
def __enter__(self) -> "PromptStore":
|
|
168
|
+
return self
|
|
169
|
+
|
|
170
|
+
def __exit__(self, *args: Any) -> None:
|
|
171
|
+
self.close()
|
duragraph/py.typed
ADDED
|
File without changes
|
duragraph/types.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Type definitions for DuraGraph SDK."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, TypedDict, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
# State is a dictionary that flows through the graph; nodes read from it
# and produce updates that are merged back in.
State = dict[str, Any]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Message(BaseModel):
    """Base message type.

    Common shape shared by all chat messages; concrete subclasses pin
    ``role`` to a single literal value.
    """

    # Which party produced the message.
    role: Literal["human", "assistant", "tool", "system"]
    # The message text.
    content: str
    # Optional author/display name.
    name: str | None = None
    # Free-form per-message metadata.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class HumanMessage(Message):
    """Message from a human user."""

    # Role is fixed for this subclass.
    role: Literal["human"] = "human"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AIMessage(Message):
    """Message from an AI assistant."""

    role: Literal["assistant"] = "assistant"
    # Tool invocations requested by the model, if any.
    tool_calls: list[dict[str, Any]] | None = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ToolMessage(Message):
    """Result from a tool call."""

    role: Literal["tool"] = "tool"
    # ID of the tool call this message responds to.
    tool_call_id: str
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SystemMessage(Message):
    """System message for LLM context."""

    role: Literal["system"] = "system"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
# Union of all concrete message types.
AnyMessage = Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class NodeConfig(TypedDict, total=False):
    """Configuration for a node.

    All keys are optional (``total=False``).
    """

    model: str  # LLM model identifier
    temperature: float
    max_tokens: int
    system_prompt: str
    tools: list[str]  # names of tools available to the node
    stream: bool
    retry_on: list[str]  # presumably error-type names that trigger a retry — TODO confirm
    max_retries: int
    retry_delay: float  # presumably seconds between retries — TODO confirm units
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class GraphConfig(TypedDict, total=False):
    """Configuration for graph execution.

    All keys are optional (``total=False``).
    """

    checkpoint_id: str  # resume from this checkpoint — presumably; TODO confirm
    stream_mode: list[Literal["values", "updates", "messages", "events"]]
    recursion_limit: int
    timeout: float  # presumably seconds — TODO confirm units
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class RunResult(BaseModel):
    """Result of a graph execution."""

    run_id: str
    status: Literal["completed", "failed", "interrupted", "cancelled"]
    # Final state produced by the run.
    output: dict[str, Any]
    # Error description — presumably populated when status == "failed"; TODO confirm.
    error: str | None = None
    nodes_executed: list[str] = Field(default_factory=list)
    # Token usage counts, if reported by the runtime.
    tokens: dict[str, int] | None = None
    duration_ms: float | None = None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class Event(BaseModel):
    """Streaming event from graph execution."""

    # Event discriminator. "run_requires_action" is included because the
    # worker emits it for human-in-the-loop nodes (see
    # Worker._execute_human_node); it was missing from the original Literal
    # even though the worker sends it.
    type: Literal[
        "run_started",
        "run_completed",
        "run_failed",
        "run_requires_action",
        "node_started",
        "node_completed",
        "token",
        "checkpoint",
    ]
    run_id: str
    # Set for node-scoped events such as "node_started"/"node_completed".
    node_id: str | None = None
    data: dict[str, Any] = Field(default_factory=dict)
    # Event time as a string — presumably ISO-8601; TODO confirm.
    timestamp: str
|
|
@@ -0,0 +1,327 @@
|
|
|
1
|
+
"""Worker implementation for DuraGraph control plane."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import signal
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from typing import Any
|
|
7
|
+
from uuid import uuid4
|
|
8
|
+
|
|
9
|
+
import httpx
|
|
10
|
+
|
|
11
|
+
from duragraph.graph import GraphDefinition
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Worker:
    """Worker that connects to DuraGraph control plane and executes graphs.

    Lifecycle: register graph definitions locally via ``register_graph``,
    then call ``run`` (blocking) or ``arun`` (async). The worker registers
    itself with the control plane, polls for runs, executes them node by
    node, and streams lifecycle events back.
    """

    def __init__(
        self,
        control_plane_url: str,
        *,
        name: str | None = None,
        capabilities: list[str] | None = None,
        poll_interval: float = 1.0,
    ):
        """Initialize worker.

        Args:
            control_plane_url: URL of the DuraGraph control plane.
            name: Optional name for this worker; a random "worker-<hex8>"
                name is generated when omitted.
            capabilities: Optional list of capabilities (e.g., ["openai", "tools"]).
            poll_interval: Interval in seconds between polling for work.
        """
        self.control_plane_url = control_plane_url.rstrip("/")
        self.name = name or f"worker-{uuid4().hex[:8]}"
        self.capabilities = capabilities or []
        self.poll_interval = poll_interval

        # Assigned by the control plane during registration.
        self._worker_id: str | None = None
        # Graph definitions this worker can execute, keyed by graph_id.
        self._graphs: dict[str, GraphDefinition] = {}
        # NOTE(review): custom executors are stored here but never consulted
        # by _execute_run — confirm whether that is intentional.
        self._executors: dict[str, Callable[..., Any]] = {}
        self._running = False
        # Created lazily on first registration.
        self._client: httpx.AsyncClient | None = None

    def register_graph(
        self,
        definition: GraphDefinition,
        executor: Callable[..., Any] | None = None,
    ) -> None:
        """Register a graph definition with this worker.

        Args:
            definition: The graph definition to register.
            executor: Optional custom executor function.
        """
        self._graphs[definition.graph_id] = definition
        if executor:
            self._executors[definition.graph_id] = executor

    async def _register_with_control_plane(self) -> str:
        """Register this worker with the control plane.

        Returns:
            The worker_id assigned by the control plane.
        """
        # The async client is created lazily on first use.
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=30.0)

        # Prepare graph definitions in their IR form for the control plane.
        graphs = [
            {"graph_id": g.graph_id, "definition": g.to_ir()}
            for g in self._graphs.values()
        ]

        payload = {
            "name": self.name,
            "capabilities": self.capabilities,
            "graphs": graphs,
        }

        response = await self._client.post(
            f"{self.control_plane_url}/api/v1/workers/register",
            json=payload,
        )
        response.raise_for_status()

        data = response.json()
        return data["worker_id"]

    async def _poll_for_work(self) -> dict[str, Any] | None:
        """Poll the control plane for work.

        Returns:
            A work-item dict, or None when there is nothing to do (no work
            available, not yet registered, or a transient error occurred).
        """
        if self._client is None or self._worker_id is None:
            return None

        try:
            response = await self._client.get(
                f"{self.control_plane_url}/api/v1/workers/{self._worker_id}/poll",
            )
            # 204 means "no work available right now".
            if response.status_code == 204:
                return None
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                # Worker unknown to the control plane (e.g. after a restart):
                # re-register now; work will be picked up on the next poll.
                self._worker_id = await self._register_with_control_plane()
                return None
        except Exception:
            # Best-effort polling: swallow transient errors and retry later.
            return None

    async def _execute_run(self, work: dict[str, Any]) -> None:
        """Execute a run from the control plane.

        Walks the graph from its entrypoint one node at a time, merging each
        node's dict result into the running state. Lifecycle events are
        emitted along the way; failures are reported as a "run_failed" event
        rather than raised.
        """
        run_id = work.get("run_id")
        graph_id = work.get("graph_id")
        input_data = work.get("input", {})
        thread_id = work.get("thread_id")

        # Malformed work item: nothing useful can be reported back.
        if not run_id or not graph_id:
            return

        # Find the graph definition
        graph_def = self._graphs.get(graph_id)
        if not graph_def:
            await self._send_event(run_id, "run_failed", {
                "error": f"Graph '{graph_id}' not registered with this worker",
            })
            return

        # Start the run
        await self._send_event(run_id, "run_started", {"thread_id": thread_id})

        try:
            # Execute nodes sequentially, threading the state through.
            state = input_data.copy()
            current_node = graph_def.entrypoint

            while current_node:
                await self._send_event(run_id, "node_started", {
                    "node_id": current_node,
                })

                # Get node metadata
                node_meta = graph_def.nodes.get(current_node)
                if not node_meta:
                    raise ValueError(f"Node '{current_node}' not found")

                # Execute based on node type
                if node_meta.node_type == "llm":
                    result = await self._execute_llm_node(node_meta, state)
                elif node_meta.node_type == "tool":
                    result = await self._execute_tool_node(node_meta, state)
                elif node_meta.node_type == "human":
                    result = await self._execute_human_node(
                        run_id, node_meta, state
                    )
                    if result is None:
                        # Interrupted, waiting for human input; the run stays
                        # open and is expected to resume out-of-band.
                        return
                else:
                    # Default function node - just pass through
                    result = state

                # Dict results are merged into the state; non-dict results
                # (e.g. a branch-label string) leave the state untouched.
                if isinstance(result, dict):
                    state.update(result)

                await self._send_event(run_id, "node_completed", {
                    "node_id": current_node,
                    "output": result,
                })

                # Find next node: the first edge whose source matches wins.
                # A dict target is a conditional edge keyed by the node's
                # string result; no match means the run ends here.
                next_node = None
                for edge in graph_def.edges:
                    if edge.source == current_node:
                        if isinstance(edge.target, str):
                            next_node = edge.target
                        elif isinstance(edge.target, dict):
                            if isinstance(result, str) and result in edge.target:
                                next_node = edge.target[result]
                        break

                current_node = next_node

            # Run completed
            await self._send_event(run_id, "run_completed", {
                "output": state,
                "thread_id": thread_id,
            })

        except Exception as e:
            # Report the failure instead of crashing the worker loop.
            await self._send_event(run_id, "run_failed", {
                "error": str(e),
                "thread_id": thread_id,
            })

    async def _execute_llm_node(
        self,
        node_meta: Any,
        state: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute an LLM node."""
        # Placeholder - would integrate with LLM providers
        config = node_meta.config
        model = config.get("model", "gpt-4o-mini")

        # For now, just echo the state
        return {"llm_response": f"[{model}] Processed state"}

    async def _execute_tool_node(
        self,
        node_meta: Any,
        state: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute a tool node."""
        # Placeholder - would execute registered tools
        return state

    async def _execute_human_node(
        self,
        run_id: str,
        node_meta: Any,
        state: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Execute a human-in-the-loop node.

        Returns:
            None, signalling to the caller that the run is now waiting for
            human input and must stop advancing.
        """
        config = node_meta.config
        prompt = config.get("prompt", "Please review")

        # Signal that human input is required
        # NOTE(review): verify that the control plane and the Event.type
        # literals in duragraph.types both accept "run_requires_action".
        await self._send_event(run_id, "run_requires_action", {
            "action_type": "human_review",
            "prompt": prompt,
            "state": state,
        })

        # Return None to indicate the run is waiting
        return None

    async def _send_event(
        self,
        run_id: str,
        event_type: str,
        data: dict[str, Any],
    ) -> None:
        """Send an event to the control plane (best effort; never raises)."""
        # Events cannot be delivered before registration completes.
        if self._client is None or self._worker_id is None:
            return

        payload = {
            "run_id": run_id,
            "event_type": event_type,
            "data": data,
        }

        try:
            response = await self._client.post(
                f"{self.control_plane_url}/api/v1/workers/{self._worker_id}/events",
                json=payload,
            )
            response.raise_for_status()
        except Exception:
            pass  # Best effort

    async def _heartbeat(self) -> None:
        """Send heartbeat to control plane (best effort; never raises)."""
        if self._client is None or self._worker_id is None:
            return

        try:
            await self._client.post(
                f"{self.control_plane_url}/api/v1/workers/{self._worker_id}/heartbeat",
            )
        except Exception:
            pass

    async def _run_loop(self) -> None:
        """Main worker loop: register, then poll/execute until stopped."""
        self._running = True

        # Register with control plane
        print(f"Registering worker '{self.name}' with control plane...")
        self._worker_id = await self._register_with_control_plane()
        print(f"Registered with worker_id: {self._worker_id}")

        heartbeat_counter = 0

        while self._running:
            # Poll for work
            work = await self._poll_for_work()
            if work:
                print(f"Received work: {work.get('run_id')}")
                await self._execute_run(work)

            # Periodic heartbeat; effective cadence is 30 * poll_interval
            # seconds (longer while a run is executing).
            heartbeat_counter += 1
            if heartbeat_counter >= 30:  # Every 30 poll intervals
                await self._heartbeat()
                heartbeat_counter = 0

            await asyncio.sleep(self.poll_interval)

    def run(self) -> None:
        """Run the worker (blocking)."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Handle shutdown signals
        # NOTE(review): loop.add_signal_handler raises NotImplementedError on
        # Windows event loops — confirm target platforms.
        for sig in (signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(sig, self._shutdown)

        try:
            loop.run_until_complete(self._run_loop())
        finally:
            # Close the HTTP client before tearing down the loop.
            if self._client:
                loop.run_until_complete(self._client.aclose())
            loop.close()

    async def arun(self) -> None:
        """Run the worker asynchronously."""
        try:
            await self._run_loop()
        finally:
            if self._client:
                await self._client.aclose()

    def _shutdown(self) -> None:
        """Signal handler: request a graceful stop of the run loop."""
        print("\nShutting down worker...")
        self._running = False

    def stop(self) -> None:
        """Stop the worker (the loop exits after its current iteration)."""
        self._running = False
|