duragraph-python 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,43 @@
1
+ """Prompt decorators for DuraGraph."""
2
+
3
+ from collections.abc import Callable
4
+ from functools import wraps
5
+ from typing import Any, TypeVar
6
+
7
+ F = TypeVar("F", bound=Callable[..., Any])
8
+
9
+
10
+ def prompt(
11
+ prompt_id: str,
12
+ *,
13
+ version: str | None = None,
14
+ variant: str | None = None,
15
+ ) -> Callable[[F], F]:
16
+ """Decorator to attach a prompt from the prompt store to a node.
17
+
18
+ Args:
19
+ prompt_id: Identifier for the prompt (e.g., "support/classify_intent").
20
+ version: Optional specific version (e.g., "2.1.0"). Defaults to latest.
21
+ variant: Optional A/B test variant.
22
+
23
+ Example:
24
+ @llm_node(model="gpt-4o-mini")
25
+ @prompt("support/classify_intent", version="2.1.0")
26
+ def classify(self, state):
27
+ return state
28
+ """
29
+
30
+ def decorator(func: F) -> F:
31
+ @wraps(func)
32
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
33
+ return func(*args, **kwargs)
34
+
35
+ # Attach prompt metadata
36
+ wrapper._prompt_metadata = { # type: ignore
37
+ "prompt_id": prompt_id,
38
+ "version": version,
39
+ "variant": variant,
40
+ }
41
+ return wrapper # type: ignore
42
+
43
+ return decorator
@@ -0,0 +1,171 @@
1
+ """Prompt store client for DuraGraph."""
2
+
3
+ from typing import Any
4
+
5
+ import httpx
6
+
7
+
8
+ class PromptStore:
9
+ """Client for interacting with the DuraGraph Prompt Store."""
10
+
11
+ def __init__(
12
+ self,
13
+ base_url: str,
14
+ *,
15
+ api_key: str | None = None,
16
+ ):
17
+ """Initialize prompt store client.
18
+
19
+ Args:
20
+ base_url: URL of the prompt store API.
21
+ api_key: Optional API key for authentication.
22
+ """
23
+ self.base_url = base_url.rstrip("/")
24
+ self.api_key = api_key
25
+ self._client = httpx.Client(timeout=30.0)
26
+
27
+ def _headers(self) -> dict[str, str]:
28
+ """Get request headers."""
29
+ headers = {"Content-Type": "application/json"}
30
+ if self.api_key:
31
+ headers["Authorization"] = f"Bearer {self.api_key}"
32
+ return headers
33
+
34
+ def get_prompt(
35
+ self,
36
+ prompt_id: str,
37
+ *,
38
+ version: str | None = None,
39
+ variant: str | None = None,
40
+ ) -> dict[str, Any]:
41
+ """Get a prompt from the store.
42
+
43
+ Args:
44
+ prompt_id: Prompt identifier.
45
+ version: Optional version (default: latest).
46
+ variant: Optional A/B variant.
47
+
48
+ Returns:
49
+ Prompt data including content and metadata.
50
+ """
51
+ params: dict[str, str] = {}
52
+ if version:
53
+ params["version"] = version
54
+ if variant:
55
+ params["variant"] = variant
56
+
57
+ response = self._client.get(
58
+ f"{self.base_url}/api/v1/prompts/{prompt_id}",
59
+ headers=self._headers(),
60
+ params=params,
61
+ )
62
+ response.raise_for_status()
63
+ return response.json()
64
+
65
+ def list_prompts(
66
+ self,
67
+ *,
68
+ namespace: str | None = None,
69
+ tag: str | None = None,
70
+ ) -> list[dict[str, Any]]:
71
+ """List prompts in the store.
72
+
73
+ Args:
74
+ namespace: Optional namespace filter.
75
+ tag: Optional tag filter.
76
+
77
+ Returns:
78
+ List of prompt metadata.
79
+ """
80
+ params: dict[str, str] = {}
81
+ if namespace:
82
+ params["namespace"] = namespace
83
+ if tag:
84
+ params["tag"] = tag
85
+
86
+ response = self._client.get(
87
+ f"{self.base_url}/api/v1/prompts",
88
+ headers=self._headers(),
89
+ params=params,
90
+ )
91
+ response.raise_for_status()
92
+ return response.json()["prompts"]
93
+
94
+ def create_prompt(
95
+ self,
96
+ prompt_id: str,
97
+ content: str,
98
+ *,
99
+ description: str | None = None,
100
+ tags: list[str] | None = None,
101
+ metadata: dict[str, Any] | None = None,
102
+ ) -> dict[str, Any]:
103
+ """Create a new prompt.
104
+
105
+ Args:
106
+ prompt_id: Prompt identifier.
107
+ content: Prompt content template.
108
+ description: Optional description.
109
+ tags: Optional tags for categorization.
110
+ metadata: Optional additional metadata.
111
+
112
+ Returns:
113
+ Created prompt data.
114
+ """
115
+ payload = {
116
+ "prompt_id": prompt_id,
117
+ "content": content,
118
+ }
119
+ if description:
120
+ payload["description"] = description
121
+ if tags:
122
+ payload["tags"] = tags
123
+ if metadata:
124
+ payload["metadata"] = metadata
125
+
126
+ response = self._client.post(
127
+ f"{self.base_url}/api/v1/prompts",
128
+ headers=self._headers(),
129
+ json=payload,
130
+ )
131
+ response.raise_for_status()
132
+ return response.json()
133
+
134
+ def create_version(
135
+ self,
136
+ prompt_id: str,
137
+ content: str,
138
+ *,
139
+ change_log: str | None = None,
140
+ ) -> dict[str, Any]:
141
+ """Create a new version of an existing prompt.
142
+
143
+ Args:
144
+ prompt_id: Prompt identifier.
145
+ content: New prompt content.
146
+ change_log: Optional change description.
147
+
148
+ Returns:
149
+ New version data.
150
+ """
151
+ payload = {"content": content}
152
+ if change_log:
153
+ payload["change_log"] = change_log
154
+
155
+ response = self._client.post(
156
+ f"{self.base_url}/api/v1/prompts/{prompt_id}/versions",
157
+ headers=self._headers(),
158
+ json=payload,
159
+ )
160
+ response.raise_for_status()
161
+ return response.json()
162
+
163
+ def close(self) -> None:
164
+ """Close the HTTP client."""
165
+ self._client.close()
166
+
167
+ def __enter__(self) -> "PromptStore":
168
+ return self
169
+
170
+ def __exit__(self, *args: Any) -> None:
171
+ self.close()
duragraph/py.typed ADDED
File without changes
duragraph/types.py ADDED
@@ -0,0 +1,100 @@
1
+ """Type definitions for DuraGraph SDK."""
2
+
3
+ from typing import Any, Literal, TypedDict, Union
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
# State is a dictionary that flows through the graph: a plain alias for
# ``dict[str, Any]`` (string keys, arbitrary values), not a distinct type.
State = dict[str, Any]
9
+
10
+
11
class Message(BaseModel):
    """Base message type.

    Holds the fields shared by every chat message. The subclasses below narrow
    ``role`` to a single literal and give it a matching default.
    """

    # One of "human", "assistant", "tool" or "system".
    role: Literal["human", "assistant", "tool", "system"]
    # Message text.
    content: str
    # Optional name attached to the message; None when unset.
    name: str | None = None
    # Free-form extra data; each instance gets its own empty dict by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
18
+
19
+
20
class HumanMessage(Message):
    """Message from a human user.

    ``role`` is restricted to, and defaults to, "human".
    """

    role: Literal["human"] = "human"
24
+
25
+
26
class AIMessage(Message):
    """Message from an AI assistant.

    ``role`` is restricted to, and defaults to, "assistant".
    """

    role: Literal["assistant"] = "assistant"
    # Tool calls attached to the reply, or None. Each entry is an unconstrained
    # dict here; its keys are defined by whoever produces it.
    tool_calls: list[dict[str, Any]] | None = None
31
+
32
+
33
class ToolMessage(Message):
    """Result from a tool call.

    ``role`` is restricted to, and defaults to, "tool".
    """

    role: Literal["tool"] = "tool"
    # Required. Presumably matches the id of an entry in AIMessage.tool_calls;
    # confirm against the producer.
    tool_call_id: str
38
+
39
+
40
class SystemMessage(Message):
    """System message for LLM context.

    ``role`` is restricted to, and defaults to, "system".
    """

    role: Literal["system"] = "system"
44
+
45
+
46
# Union of all concrete message types, for annotations that accept any
# message (HumanMessage, AIMessage, ToolMessage or SystemMessage).
AnyMessage = Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]
48
+
49
+
50
class NodeConfig(TypedDict, total=False):
    """Configuration for a node.

    ``total=False``: every key is optional. How each key is interpreted is up
    to the node executor; the notes below are assumptions to confirm there.
    """

    # Model settings (presumably forwarded to the LLM provider).
    model: str
    temperature: float
    max_tokens: int
    system_prompt: str
    # Presumably names of tools the node may call.
    tools: list[str]
    stream: bool
    # Retry policy (presumably: error names to retry on, attempt cap, and delay
    # between attempts; units of retry_delay not fixed here, likely seconds).
    retry_on: list[str]
    max_retries: int
    retry_delay: float
62
+
63
+
64
class GraphConfig(TypedDict, total=False):
    """Configuration for graph execution.

    ``total=False``: every key is optional.
    """

    # Presumably the checkpoint to resume from; confirm with the executor.
    checkpoint_id: str
    # Which kinds of streamed output to request; any combination of the
    # listed literals.
    stream_mode: list[Literal["values", "updates", "messages", "events"]]
    # Presumably a cap on execution steps; confirm with the executor.
    recursion_limit: int
    # Units not fixed here (likely seconds); confirm with the executor.
    timeout: float
71
+
72
+
73
class RunResult(BaseModel):
    """Result of a graph execution."""

    run_id: str
    # Terminal state of the run.
    status: Literal["completed", "failed", "interrupted", "cancelled"]
    # Final output data (required, may be an empty dict).
    output: dict[str, Any]
    # Error message, if any; None by default.
    error: str | None = None
    # Ids of nodes that ran; each instance gets its own empty list by default.
    nodes_executed: list[str] = Field(default_factory=list)
    # Optional token counts keyed by name (key set not fixed here).
    tokens: dict[str, int] | None = None
    # Optional duration, presumably in milliseconds as the name suggests.
    duration_ms: float | None = None
83
+
84
+
85
class Event(BaseModel):
    """Streaming event from graph execution.

    ``type`` covers the run- and node-lifecycle event kinds, streamed tokens,
    checkpoints, and ``run_requires_action`` (emitted by the worker when a
    human-in-the-loop node pauses a run). Without that literal, validating
    such an event would fail.
    """

    type: Literal[
        "run_started",
        "run_completed",
        "run_failed",
        "run_requires_action",
        "node_started",
        "node_completed",
        "token",
        "checkpoint",
    ]
    run_id: str
    # Node the event refers to, when it is node-scoped; None otherwise.
    node_id: str | None = None
    # Event payload; each instance gets its own empty dict by default.
    data: dict[str, Any] = Field(default_factory=dict)
    # Required. Kept as a string; its format is not validated here.
    timestamp: str
@@ -0,0 +1,5 @@
1
+ """Worker module for DuraGraph control plane integration."""
2
+
3
+ from duragraph.worker.worker import Worker
4
+
5
# Names exported by ``from <this module> import *``: only the Worker class.
__all__ = ["Worker"]
@@ -0,0 +1,327 @@
1
+ """Worker implementation for DuraGraph control plane."""
2
+
3
+ import asyncio
4
+ import signal
5
+ from collections.abc import Callable
6
+ from typing import Any
7
+ from uuid import uuid4
8
+
9
+ import httpx
10
+
11
+ from duragraph.graph import GraphDefinition
12
+
13
+
14
class Worker:
    """Worker that connects to DuraGraph control plane and executes graphs.

    The worker registers itself, together with the graphs added through
    ``register_graph``, with the control plane. It then polls for runs every
    ``poll_interval`` seconds and executes them one at a time, reporting
    progress as events. Use ``run()`` as a blocking entry point or
    ``await arun()`` inside an existing event loop. ``stop()`` ends the poll
    loop after its current iteration, and the worker may be started again
    afterwards.
    """

    def __init__(
        self,
        control_plane_url: str,
        *,
        name: str | None = None,
        capabilities: list[str] | None = None,
        poll_interval: float = 1.0,
    ):
        """Initialize worker.

        Args:
            control_plane_url: URL of the DuraGraph control plane (a trailing
                "/" is stripped).
            name: Optional name for this worker; defaults to "worker-" followed
                by 8 random hex characters.
            capabilities: Optional list of capabilities (e.g., ["openai", "tools"]).
            poll_interval: Interval in seconds between polling for work.
        """
        self.control_plane_url = control_plane_url.rstrip("/")
        self.name = name or f"worker-{uuid4().hex[:8]}"
        self.capabilities = capabilities or []
        self.poll_interval = poll_interval

        self._worker_id: str | None = None
        self._graphs: dict[str, GraphDefinition] = {}
        # Custom executors are stored here but not consulted by _execute_run.
        self._executors: dict[str, Callable[..., Any]] = {}
        self._running = False
        # Created lazily on registration; reset to None when the loop exits so
        # a later run()/arun() does not reuse a closed client.
        self._client: httpx.AsyncClient | None = None

    def register_graph(
        self,
        definition: GraphDefinition,
        executor: Callable[..., Any] | None = None,
    ) -> None:
        """Register a graph definition with this worker.

        Graphs are keyed by ``definition.graph_id``; registering the same id
        again replaces the earlier definition. Only graphs registered before
        the worker registers with the control plane are announced to it.

        Args:
            definition: The graph definition to register.
            executor: Optional custom executor function.
        """
        self._graphs[definition.graph_id] = definition
        if executor:
            self._executors[definition.graph_id] = executor

    async def _register_with_control_plane(self) -> str:
        """Register this worker with the control plane.

        Creates the async HTTP client on first use (30 s timeout) and POSTs the
        worker's name, capabilities and each registered graph's IR.

        Returns:
            The ``worker_id`` assigned by the control plane.

        Raises:
            httpx.HTTPStatusError: If the control plane returns a non-2xx status.
        """
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=30.0)

        # Prepare graph definitions
        graphs = [
            {"graph_id": g.graph_id, "definition": g.to_ir()}
            for g in self._graphs.values()
        ]

        payload = {
            "name": self.name,
            "capabilities": self.capabilities,
            "graphs": graphs,
        }

        response = await self._client.post(
            f"{self.control_plane_url}/api/v1/workers/register",
            json=payload,
        )
        response.raise_for_status()

        data = response.json()
        return data["worker_id"]

    async def _poll_for_work(self) -> dict[str, Any] | None:
        """Poll the control plane for work.

        Polling is best effort. The main loop simply retries on its next tick,
        so this returns None when there is no work (HTTP 204), when the worker
        is not registered yet, or when the request fails. A 404 means the
        control plane no longer knows this worker, so it re-registers first.

        Returns:
            The decoded work item, or None.
        """
        if self._client is None or self._worker_id is None:
            return None

        try:
            response = await self._client.get(
                f"{self.control_plane_url}/api/v1/workers/{self._worker_id}/poll",
            )
            if response.status_code == 204:
                return None
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                # Worker not found, re-register. A failure here is treated like
                # any other poll failure rather than escaping and killing the
                # main loop; the next 404 retries the registration.
                try:
                    self._worker_id = await self._register_with_control_plane()
                except Exception:
                    pass
            return None
        except Exception:
            return None

    async def _execute_run(self, work: dict[str, Any]) -> None:
        """Execute a run from the control plane.

        Walks the graph from its entrypoint one node at a time. It emits
        ``run_started``, then ``node_started``/``node_completed`` for each node,
        and finally ``run_completed`` or ``run_failed``. A human node pauses
        the run: the method returns without a terminal event. Work items
        without ``run_id`` or ``graph_id`` are ignored.
        """
        run_id = work.get("run_id")
        graph_id = work.get("graph_id")
        # "input" may be absent or explicitly null; both mean "no input".
        input_data = work.get("input") or {}
        thread_id = work.get("thread_id")

        if not run_id or not graph_id:
            return

        # Find the graph definition
        graph_def = self._graphs.get(graph_id)
        if not graph_def:
            await self._send_event(run_id, "run_failed", {
                "error": f"Graph '{graph_id}' not registered with this worker",
            })
            return

        # Start the run
        await self._send_event(run_id, "run_started", {"thread_id": thread_id})

        try:
            # Execute nodes
            state = dict(input_data)
            current_node = graph_def.entrypoint

            while current_node:
                await self._send_event(run_id, "node_started", {
                    "node_id": current_node,
                })

                # Get node metadata
                node_meta = graph_def.nodes.get(current_node)
                if not node_meta:
                    raise ValueError(f"Node '{current_node}' not found")

                # Execute based on node type
                if node_meta.node_type == "llm":
                    result = await self._execute_llm_node(node_meta, state)
                elif node_meta.node_type == "tool":
                    result = await self._execute_tool_node(node_meta, state)
                elif node_meta.node_type == "human":
                    result = await self._execute_human_node(
                        run_id, node_meta, state
                    )
                    if result is None:
                        # Interrupted, waiting for human input
                        return
                else:
                    # Default function node - just pass through
                    result = state

                if isinstance(result, dict):
                    state.update(result)

                await self._send_event(run_id, "node_completed", {
                    "node_id": current_node,
                    "output": result,
                })

                current_node = self._resolve_next_node(
                    graph_def, current_node, result
                )

            # Run completed
            await self._send_event(run_id, "run_completed", {
                "output": state,
                "thread_id": thread_id,
            })

        except Exception as e:
            await self._send_event(run_id, "run_failed", {
                "error": str(e),
                "thread_id": thread_id,
            })

    @staticmethod
    def _resolve_next_node(
        graph_def: GraphDefinition,
        current_node: str,
        result: Any,
    ) -> str | None:
        """Pick the node that follows ``current_node``, or None to end the run.

        Outgoing edges are checked in declaration order and the first one that
        resolves wins. A ``str`` target always resolves; a ``dict`` target
        resolves only when ``result`` is a ``str`` key of that mapping.
        """
        for edge in graph_def.edges:
            if edge.source != current_node:
                continue
            if isinstance(edge.target, str):
                return edge.target
            # NOTE(review): none of the node handlers in this class return a
            # ``str`` result, so mapping (conditional) targets cannot match yet.
            if (
                isinstance(edge.target, dict)
                and isinstance(result, str)
                and result in edge.target
            ):
                return edge.target[result]
        return None

    async def _execute_llm_node(
        self,
        node_meta: Any,
        state: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute an LLM node.

        Placeholder: no provider is called. It returns an ``llm_response``
        string naming the configured model ("gpt-4o-mini" when unset).
        """
        config = node_meta.config
        model = config.get("model", "gpt-4o-mini")

        # For now, just echo the state
        return {"llm_response": f"[{model}] Processed state"}

    async def _execute_tool_node(
        self,
        node_meta: Any,
        state: dict[str, Any],
    ) -> dict[str, Any]:
        """Execute a tool node.

        Placeholder: returns ``state`` unchanged (no tool is invoked).
        """
        return state

    async def _execute_human_node(
        self,
        run_id: str,
        node_meta: Any,
        state: dict[str, Any],
    ) -> dict[str, Any] | None:
        """Execute a human-in-the-loop node.

        Emits ``run_requires_action`` with the node's configured prompt
        ("Please review" when unset) and the current state. It always returns
        None, which tells ``_execute_run`` to pause the run.
        """
        config = node_meta.config
        prompt = config.get("prompt", "Please review")

        # Signal that human input is required
        await self._send_event(run_id, "run_requires_action", {
            "action_type": "human_review",
            "prompt": prompt,
            "state": state,
        })

        # Return None to indicate the run is waiting
        return None

    async def _send_event(
        self,
        run_id: str,
        event_type: str,
        data: dict[str, Any],
    ) -> None:
        """Send an event to the control plane.

        Best effort: this is a no-op before registration, and any request or
        HTTP-status failure is ignored.
        """
        if self._client is None or self._worker_id is None:
            return

        payload = {
            "run_id": run_id,
            "event_type": event_type,
            "data": data,
        }

        try:
            response = await self._client.post(
                f"{self.control_plane_url}/api/v1/workers/{self._worker_id}/events",
                json=payload,
            )
            response.raise_for_status()
        except Exception:
            pass  # Best effort

    async def _heartbeat(self) -> None:
        """Send heartbeat to control plane.

        Best effort: no-op before registration; failures and the response
        status are ignored.
        """
        if self._client is None or self._worker_id is None:
            return

        try:
            await self._client.post(
                f"{self.control_plane_url}/api/v1/workers/{self._worker_id}/heartbeat",
            )
        except Exception:
            pass

    async def _run_loop(self) -> None:
        """Main worker loop.

        Registers with the control plane (errors propagate), then, until
        stopped: polls for work, executes any run inline, sends a heartbeat
        every 30 iterations, and sleeps ``poll_interval`` seconds.
        """
        self._running = True

        # Register with control plane
        print(f"Registering worker '{self.name}' with control plane...")
        self._worker_id = await self._register_with_control_plane()
        print(f"Registered with worker_id: {self._worker_id}")

        heartbeat_counter = 0

        while self._running:
            # Poll for work
            work = await self._poll_for_work()
            if work:
                print(f"Received work: {work.get('run_id')}")
                await self._execute_run(work)

            # Periodic heartbeat
            heartbeat_counter += 1
            if heartbeat_counter >= 30:  # Every 30 poll intervals
                await self._heartbeat()
                heartbeat_counter = 0

            await asyncio.sleep(self.poll_interval)

    def run(self) -> None:
        """Run the worker (blocking).

        Creates and installs a new event loop and, where supported, installs
        SIGTERM/SIGINT handlers that stop the worker. On exit it closes the
        HTTP client and the loop.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Handle shutdown signals. add_signal_handler is Unix-only (it raises
        # NotImplementedError elsewhere) and raises RuntimeError when the
        # handler cannot be installed, e.g. outside the main thread. In those
        # cases the worker still stops via stop() or KeyboardInterrupt.
        try:
            for sig in (signal.SIGTERM, signal.SIGINT):
                loop.add_signal_handler(sig, self._shutdown)
        except (NotImplementedError, RuntimeError):
            pass

        try:
            loop.run_until_complete(self._run_loop())
        finally:
            if self._client:
                loop.run_until_complete(self._client.aclose())
                # Drop the closed client so a restart creates a fresh one.
                self._client = None
            loop.close()

    async def arun(self) -> None:
        """Run the worker asynchronously; closes the HTTP client on exit."""
        try:
            await self._run_loop()
        finally:
            if self._client:
                await self._client.aclose()
                # Drop the closed client so a restart creates a fresh one.
                self._client = None

    def _shutdown(self) -> None:
        """Shutdown the worker (signal handler): print a notice and stop the loop."""
        print("\nShutting down worker...")
        self._running = False

    def stop(self) -> None:
        """Stop the worker; the loop exits after its current iteration."""
        self._running = False