flatagents 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ """
2
+ Google Cloud Platform backends for FlatAgents.
3
+
4
+ Provides Firestore-based persistence and result storage for
5
+ Cloud Functions and Firebase deployments.
6
+
7
+ Usage:
8
+ from flatagents.gcp import FirestoreBackend
9
+
10
+ backend = FirestoreBackend(collection="flatagents")
11
+ machine = FlatMachine(
12
+ config_file="machine.yml",
13
+ persistence=backend,
14
+ result_backend=backend
15
+ )
16
+
17
+ Requirements:
18
+ pip install google-cloud-firestore
19
+ """
20
+
21
+ from .firestore import FirestoreBackend
22
+
23
+ __all__ = [
24
+ "FirestoreBackend",
25
+ ]
@@ -0,0 +1,227 @@
1
+ """
2
+ Firestore backend for FlatAgents persistence and results.
3
+
4
+ Implements both PersistenceBackend and ResultBackend using Firestore.
5
+ Compatible with Cloud Functions, Firebase, and local emulator.
6
+
7
+ Document structure:
8
+ Collection: flatagents (configurable)
9
+ Document ID: {execution_id}
10
+ Subcollections:
11
+ - checkpoints/{step}_{event}
12
+ - results/{path}
13
+ """
14
+
15
+ import asyncio
16
+ import json
17
+ import logging
18
+ from datetime import datetime, timezone
19
+ from typing import Any, Dict, List, Optional
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Lazy import to avoid hard dependency
24
+ _firestore = None
25
+ _async_client = None
26
+
27
+
28
def _get_firestore():
    """Return the ``google.cloud.firestore`` module, importing it on first use.

    The imported module is cached in the module-level ``_firestore`` so the
    import is attempted at most once per successful load.

    Raises:
        ImportError: If ``google.cloud.firestore`` cannot be imported.
    """
    global _firestore
    if _firestore is not None:
        return _firestore

    try:
        from google.cloud import firestore as firestore_module
    except ImportError:
        raise ImportError(
            "google-cloud-firestore is required for GCP backends. "
            "Install with: pip install google-cloud-firestore"
        )

    _firestore = firestore_module
    return _firestore
40
+
41
+
42
def _get_async_client():
    """Return the shared async Firestore client, creating it on first call.

    The client is built with no arguments and cached in the module-level
    ``_async_client``.
    """
    global _async_client
    if _async_client is not None:
        return _async_client

    _async_client = _get_firestore().AsyncClient()
    return _async_client
49
+
50
+
51
class FirestoreBackend:
    """
    Combined persistence and result backend stored in Firestore.

    Provides the checkpoint methods (``save``/``load``/``delete``/``list``)
    and the result methods (``write``/``read``/``exists``/``delete``) on one
    object, so the same instance can be passed as both ``persistence`` and
    ``result_backend``.

    Args:
        collection: Root collection name (default: "flatagents").
        project: GCP project ID. When falsy, the client is created without
            an explicit project.

    Document layout written by this class:
        {collection}/{execution_id}/checkpoints/{doc_id} = checkpoint data
        {collection}/{execution_id}/results/{path}       = result data

    Each stored document has two fields: ``data`` (a string) and
    ``created_at`` (UTC ISO-8601 timestamp).
    """

    def __init__(
        self,
        collection: str = "flatagents",
        project: Optional[str] = None
    ):
        self.collection = collection
        self.project = project
        # Client is created lazily by the ``db`` property so constructing a
        # backend does not require google-cloud-firestore to be importable.
        self._db = None

    @property
    def db(self):
        """Async Firestore client, created on first access and then reused."""
        if self._db is None:
            firestore = _get_firestore()
            if self.project:
                self._db = firestore.AsyncClient(project=self.project)
            else:
                self._db = firestore.AsyncClient()
        return self._db

    def _doc_ref(self, execution_id: str, subcollection: str, doc_id: str):
        """Return the reference for
        ``{collection}/{execution_id}/{subcollection}/{doc_id}``."""
        return (
            self.db.collection(self.collection)
            .document(execution_id)
            .collection(subcollection)
            .document(doc_id)
        )

    @staticmethod
    def _split_key(key: str) -> tuple:
        """Split a checkpoint key into ``(execution_id, doc_id)``.

        Everything before the first "/" is the execution id; the remainder is
        the checkpoint document id. A key without "/" maps to the document id
        "latest".
        """
        execution_id, separator, doc_id = key.partition("/")
        return execution_id, (doc_id if separator else "latest")

    # =========================================================================
    # PersistenceBackend Interface
    # =========================================================================

    async def save(self, key: str, value: bytes) -> None:
        """Save checkpoint data.

        Args:
            key: Format "{execution_id}/step_{step}_{event}"
            value: UTF-8 encoded checkpoint bytes (stored as a string field)
        """
        execution_id, doc_id = self._split_key(key)
        doc_ref = self._doc_ref(execution_id, "checkpoints", doc_id)

        await doc_ref.set({
            "data": value.decode("utf-8"),
            "created_at": datetime.now(timezone.utc).isoformat(),
        })

        logger.debug(f"Firestore: saved checkpoint {execution_id}/{doc_id}")

    async def load(self, key: str) -> Optional[bytes]:
        """Load checkpoint data.

        Returns:
            The stored checkpoint re-encoded as UTF-8 bytes, or None when no
            document exists for ``key``.
        """
        execution_id, doc_id = self._split_key(key)
        doc = await self._doc_ref(execution_id, "checkpoints", doc_id).get()

        if not doc.exists:
            return None

        return doc.to_dict()["data"].encode("utf-8")

    async def list(self, prefix: str) -> List[str]:
        """List checkpoint keys matching ``prefix``, sorted ascending.

        Args:
            prefix: "{execution_id}" or "{execution_id}/" lists every
                checkpoint of that execution; "{execution_id}/{partial}"
                additionally keeps only checkpoints whose document id starts
                with ``partial``.

        Returns:
            Keys in the form "{execution_id}/{doc_id}".
        """
        # Only the execution id may be used as the parent document id; any
        # remainder is applied as a filter on the checkpoint document ids.
        execution_id, _, doc_prefix = prefix.partition("/")
        doc_prefix = doc_prefix.strip("/")

        collection_ref = (
            self.db.collection(self.collection)
            .document(execution_id)
            .collection("checkpoints")
        )

        keys = [
            f"{execution_id}/{doc.id}"
            async for doc in collection_ref.stream()
            if doc.id.startswith(doc_prefix)
        ]
        return sorted(keys)

    # =========================================================================
    # ResultBackend Interface
    # =========================================================================

    async def write(self, uri: str, data: Any) -> None:
        """Write a result to a URI.

        ``data`` is serialized with ``json.dumps`` before being stored.
        """
        from ..backends import parse_uri

        execution_id, path = parse_uri(uri)
        doc_ref = self._doc_ref(execution_id, "results", path)

        await doc_ref.set({
            "data": json.dumps(data),
            "created_at": datetime.now(timezone.utc).isoformat(),
        })

        logger.debug(f"Firestore: wrote result {execution_id}/{path}")

    async def read(
        self,
        uri: str,
        block: bool = True,
        timeout: Optional[float] = None
    ) -> Any:
        """Read a result from a URI, optionally polling until it is available.

        Args:
            uri: Result URI understood by ``parse_uri``.
            block: When False, return None immediately if the result is
                missing instead of polling.
            timeout: Maximum number of seconds to wait. A falsy value
                (None or 0) means wait indefinitely.

        Returns:
            The JSON-decoded result, or None when ``block`` is False and no
            result exists yet.

        Raises:
            TimeoutError: If ``timeout`` elapses before the result appears.
        """
        from ..backends import parse_uri

        execution_id, path = parse_uri(uri)
        doc_ref = self._doc_ref(execution_id, "results", path)

        # The event loop clock is monotonic, so wall-clock adjustments cannot
        # shorten or extend the wait.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        poll_interval = 0.5

        while True:
            doc = await doc_ref.get()

            if doc.exists:
                return json.loads(doc.to_dict()["data"])

            if not block:
                return None

            delay = poll_interval
            if timeout:
                remaining = timeout - (loop.time() - start_time)
                if remaining <= 0:
                    raise TimeoutError(f"Timeout waiting for result at {uri}")
                # Never sleep past the deadline.
                delay = min(delay, remaining)

            await asyncio.sleep(delay)
            # Back off gradually, capped at 5 seconds between polls.
            poll_interval = min(poll_interval * 1.5, 5.0)

    async def exists(self, uri: str) -> bool:
        """Check whether a result exists at ``uri``."""
        from ..backends import parse_uri

        execution_id, path = parse_uri(uri)
        doc = await self._doc_ref(execution_id, "results", path).get()

        return doc.exists

    # =========================================================================
    # Shared by both interfaces
    # =========================================================================

    async def delete(self, uri: str) -> None:
        """Delete a checkpoint or a result.

        Both the persistence and the result interface name this operation
        ``delete``, and a class can hold only one method of that name, so the
        argument decides which document is removed:

        - a value containing "://" is treated as a result URI and parsed
          with ``parse_uri``;
        - anything else is treated as a checkpoint key in the same
          "{execution_id}/{doc_id}" format accepted by ``save``/``load``.

        NOTE(review): this assumes every result URI accepted by ``parse_uri``
        contains a "://" scheme separator -- confirm against
        ``flatagents.backends``.
        """
        if "://" in uri:
            from ..backends import parse_uri

            execution_id, path = parse_uri(uri)
            doc_ref = self._doc_ref(execution_id, "results", path)
        else:
            execution_id, doc_id = self._split_key(uri)
            doc_ref = self._doc_ref(execution_id, "checkpoints", doc_id)

        await doc_ref.delete()
flatagents/hooks.py ADDED
@@ -0,0 +1,380 @@
1
+ """
2
+ MachineHooks - Extensibility points for FlatMachine.
3
+
4
+ Hooks allow custom logic at key points in machine execution:
5
+ - Before/after state entry/exit
6
+ - Before/after agent calls
7
+ - On transitions
8
+ - On errors
9
+
10
+ Includes built-in LoggingHooks and MetricsHooks implementations.
11
+ """
12
+
13
+ import logging
14
+ import time
15
+ from abc import ABC
16
+ from typing import Any, Dict, Optional
17
+
18
+ from .monitoring import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+ try:
23
+ import httpx
24
+ except ImportError:
25
+ httpx = None
26
+
27
+
28
class MachineHooks(ABC):
    """
    Base class for machine hooks.

    Subclass and override only the callbacks you need; every default
    implementation hands its input straight back.

    Example:
        from flatagents import get_logger
        logger = get_logger(__name__)

        class MyHooks(MachineHooks):
            def on_state_enter(self, state_name, context):
                logger.info(f"Entering state: {state_name}")
                return context

        machine = FlatMachine(config_file="...", hooks=MyHooks())
    """

    def on_machine_start(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """
        Hook invoked when machine execution starts.

        Args:
            context: Initial context

        Returns:
            The context to continue with (unchanged by default)
        """
        return context

    def on_machine_end(
        self,
        context: Dict[str, Any],
        final_output: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Hook invoked when machine execution ends.

        Args:
            context: Final context
            final_output: Output from the final state

        Returns:
            The final output to report (unchanged by default)
        """
        return final_output

    def on_state_enter(
        self,
        state_name: str,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Hook invoked before a state is executed.

        Args:
            state_name: Name of the state being entered
            context: Current context

        Returns:
            The context to continue with (unchanged by default)
        """
        return context

    def on_state_exit(
        self,
        state_name: str,
        context: Dict[str, Any],
        output: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """
        Hook invoked after a state has been executed.

        Args:
            state_name: Name of the state that was executed
            context: Current context
            output: Output from the state (agent output or None)

        Returns:
            The output to continue with (unchanged by default)
        """
        return output

    def on_transition(
        self,
        from_state: str,
        to_state: str,
        context: Dict[str, Any]
    ) -> str:
        """
        Hook invoked when moving from one state to another.

        Returning a different name overrides the target state.

        Args:
            from_state: Source state name
            to_state: Target state name (from transition evaluation)
            context: Current context

        Returns:
            Name of the state to actually enter (``to_state`` by default)
        """
        return to_state

    def on_error(
        self,
        state_name: str,
        error: Exception,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """
        Hook invoked when an error occurs during state execution.

        Args:
            state_name: Name of the state where the error occurred
            error: The exception that was raised
            context: Current context

        Returns:
            State to transition to, or None to re-raise the error
        """
        # No recovery state by default, so the error is re-raised.
        return None

    def on_action(
        self,
        action_name: str,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Hook invoked for custom hook actions defined in states.

        The default implementation logs a warning that the action was not
        handled and leaves the context untouched.

        Args:
            action_name: Name of the action to execute
            context: Current context

        Returns:
            The context to continue with (unchanged by default)
        """
        logger.warning(f"Unhandled action: {action_name}")
        return context
161
+
162
+
163
class LoggingHooks(MachineHooks):
    """Hooks that write a log record for machine start/end, state entry/exit
    and transitions, then pass their input through unchanged.

    Args:
        log_level: Level used for every record (default: ``logging.INFO``).
    """

    def __init__(self, log_level: int = logging.INFO):
        self.log_level = log_level

    def on_machine_start(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Log the start of execution and return ``context`` unchanged."""
        message = "Machine execution started"
        logger.log(self.log_level, message)
        return context

    def on_machine_end(
        self,
        context: Dict[str, Any],
        final_output: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Log the final output and return it unchanged."""
        message = f"Machine execution ended with output: {final_output}"
        logger.log(self.log_level, message)
        return final_output

    def on_state_enter(
        self,
        state_name: str,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Log the state being entered and return ``context`` unchanged."""
        message = f"Entering state: {state_name}"
        logger.log(self.log_level, message)
        return context

    def on_state_exit(
        self,
        state_name: str,
        context: Dict[str, Any],
        output: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Log the state being left and return ``output`` unchanged."""
        message = f"Exiting state: {state_name}"
        logger.log(self.log_level, message)
        return output

    def on_transition(
        self,
        from_state: str,
        to_state: str,
        context: Dict[str, Any]
    ) -> str:
        """Log the transition and return ``to_state`` unchanged."""
        message = f"Transition: {from_state} -> {to_state}"
        logger.log(self.log_level, message)
        return to_state
193
+
194
+
195
class MetricsHooks(MachineHooks):
    """Hooks that count state entries, transitions and errors.

    Counters live on the instance and accumulate for as long as the instance
    is used; nothing in this class resets them.
    """

    def __init__(self):
        # state name -> number of times the state was entered
        self.state_counts: Dict[str, int] = {}
        # "from->to" -> number of times that transition was taken
        self.transition_counts: Dict[str, int] = {}
        self.total_states_executed = 0
        self.error_count = 0

    def on_state_enter(
        self,
        state_name: str,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Count one entry into ``state_name``; ``context`` is returned as-is."""
        seen_so_far = self.state_counts.get(state_name, 0)
        self.state_counts[state_name] = seen_so_far + 1
        self.total_states_executed += 1
        return context

    def on_transition(
        self,
        from_state: str,
        to_state: str,
        context: Dict[str, Any]
    ) -> str:
        """Count one ``from_state->to_state`` transition; target is unchanged."""
        edge = f"{from_state}->{to_state}"
        taken_so_far = self.transition_counts.get(edge, 0)
        self.transition_counts[edge] = taken_so_far + 1
        return to_state

    def on_error(
        self,
        state_name: str,
        error: Exception,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """Count the error and return None (no recovery state is suggested)."""
        self.error_count += 1
        return None

    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics.

        The two ``*_counts`` entries are the instance's own dictionaries, not
        copies.
        """
        metrics: Dict[str, Any] = {}
        metrics["state_counts"] = self.state_counts
        metrics["transition_counts"] = self.transition_counts
        metrics["total_states_executed"] = self.total_states_executed
        metrics["error_count"] = self.error_count
        return metrics
226
+
227
+
228
class CompositeHooks(MachineHooks):
    """Compose multiple hooks together.

    Each callback is forwarded to the wrapped hooks in the order they were
    given; the value returned by one hook becomes the input of the next.

    Args:
        *hooks: Hook instances to chain.
    """

    def __init__(self, *hooks: MachineHooks):
        self.hooks = [*hooks]

    def on_machine_start(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Thread ``context`` through every hook's ``on_machine_start``."""
        current = context
        for member in self.hooks:
            current = member.on_machine_start(current)
        return current

    def on_machine_end(
        self,
        context: Dict[str, Any],
        final_output: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Thread ``final_output`` through every hook's ``on_machine_end``."""
        current = final_output
        for member in self.hooks:
            current = member.on_machine_end(context, current)
        return current

    def on_state_enter(
        self,
        state_name: str,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Thread ``context`` through every hook's ``on_state_enter``."""
        current = context
        for member in self.hooks:
            current = member.on_state_enter(state_name, current)
        return current

    def on_state_exit(
        self,
        state_name: str,
        context: Dict[str, Any],
        output: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Thread ``output`` through every hook's ``on_state_exit``."""
        current = output
        for member in self.hooks:
            current = member.on_state_exit(state_name, context, current)
        return current

    def on_transition(
        self,
        from_state: str,
        to_state: str,
        context: Dict[str, Any]
    ) -> str:
        """Let each hook in turn confirm or override the target state."""
        target = to_state
        for member in self.hooks:
            target = member.on_transition(from_state, target, context)
        return target

    def on_error(
        self,
        state_name: str,
        error: Exception,
        context: Dict[str, Any]
    ) -> Optional[str]:
        """Return the first non-None recovery state offered by a hook.

        Hooks after the one that supplies a recovery state are not called.
        Returns None when no hook offers one.
        """
        for member in self.hooks:
            recovery = member.on_error(state_name, error, context)
            if recovery is None:
                continue
            return recovery
        return None

    def on_action(
        self,
        action_name: str,
        context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Thread ``context`` through every hook's ``on_action``."""
        current = context
        for member in self.hooks:
            current = member.on_action(action_name, current)
        return current
275
+
276
+
277
class WebhookHooks(MachineHooks):
    """
    Hooks that dispatch events to an HTTP endpoint.

    Every callback POSTs a JSON document ``{"event": <name>, ...payload}`` to
    ``endpoint``. If the endpoint answers with a JSON object containing the
    relevant key, that value replaces the callback's default return value.
    Any failure is logged and the default value is used, so a broken webhook
    never raises into the caller.

    Unlike the base class, all callbacks here are coroutines.

    Requires 'httpx' installed.

    Args:
        endpoint: URL that receives the POST requests.
        timeout: Per-request timeout in seconds (default: 5.0).
        api_key: When given, sent as ``Authorization: Bearer <api_key>``.

    Raises:
        ImportError: If httpx is not installed.
    """

    def __init__(
        self,
        endpoint: str,
        timeout: float = 5.0,
        api_key: Optional[str] = None
    ):
        if httpx is None:
            raise ImportError("httpx is required for WebhookHooks")

        self.endpoint = endpoint
        self.timeout = timeout
        self.headers = {
            "Content-Type": "application/json",
            "User-Agent": "FlatAgents/0.1.0"
        }
        if api_key:
            self.headers["Authorization"] = f"Bearer {api_key}"

    async def _send(self, event: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Send one event to the webhook.

        Args:
            event: Event name, sent under the "event" key.
            payload: Extra fields merged into the request body.

        Returns:
            The decoded response when it is a JSON object. None for a 204,
            an empty body, a non-object JSON body, or any failure (HTTP error
            status, network error, invalid JSON, unserializable payload).
        """
        data = {"event": event, **payload}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.endpoint,
                    json=data,
                    headers=self.headers,
                    timeout=self.timeout
                )
                response.raise_for_status()
                # An acknowledgement without a body carries no overrides and
                # is not an error.
                if response.status_code == 204 or not response.content:
                    return None
                body = response.json()
        except Exception as e:
            logger.error(f"Webhook error ({event}): {e}")
            return None

        # Callers probe the result with `"key" in resp` / `resp["key"]`.
        # Anything other than a JSON object would make that a substring or
        # element test followed by a TypeError, so ignore such bodies.
        if not isinstance(body, dict):
            logger.warning(
                f"Webhook response for {event} is not a JSON object; ignoring"
            )
            return None
        return body

    async def on_machine_start(self, context: Dict[str, Any]) -> Dict[str, Any]:
        """Send "machine_start"; a "context" key in the reply replaces the context."""
        resp = await self._send("machine_start", {"context": context})
        if resp and "context" in resp:
            return resp["context"]
        return context

    async def on_machine_end(self, context: Dict[str, Any], final_output: Dict[str, Any]) -> Dict[str, Any]:
        """Send "machine_end"; an "output" key in the reply replaces the final output."""
        resp = await self._send("machine_end", {"context": context, "output": final_output})
        if resp and "output" in resp:
            return resp["output"]
        return final_output

    async def on_state_enter(self, state_name: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Send "state_enter"; a "context" key in the reply replaces the context."""
        resp = await self._send("state_enter", {"state": state_name, "context": context})
        if resp and "context" in resp:
            return resp["context"]
        return context

    async def on_state_exit(
        self,
        state_name: str,
        context: Dict[str, Any],
        output: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Send "state_exit"; an "output" key in the reply replaces the output."""
        resp = await self._send("state_exit", {"state": state_name, "context": context, "output": output})
        if resp and "output" in resp:
            return resp["output"]
        return output

    async def on_transition(self, from_state: str, to_state: str, context: Dict[str, Any]) -> str:
        """Send "transition"; a "to_state" key in the reply overrides the target."""
        resp = await self._send("transition", {"from": from_state, "to": to_state, "context": context})
        if resp and "to_state" in resp:
            return resp["to_state"]
        return to_state

    async def on_error(self, state_name: str, error: Exception, context: Dict[str, Any]) -> Optional[str]:
        """Send "error"; a "recovery_state" key in the reply is returned.

        The exception is transmitted as its string form plus its type name.
        Returns None (re-raise) when the reply names no recovery state.
        """
        resp = await self._send("error", {
            "state": state_name,
            "error": str(error),
            "error_type": type(error).__name__,
            "context": context
        })
        if resp and "recovery_state" in resp:
            return resp["recovery_state"]
        return None  # Re-raise

    async def on_action(self, action_name: str, context: Dict[str, Any]) -> Dict[str, Any]:
        """Send "action"; a "context" key in the reply replaces the context."""
        resp = await self._send("action", {"action": action_name, "context": context})
        if resp and "context" in resp:
            return resp["context"]
        return context
372
+
373
+
374
+ __all__ = [
375
+ "MachineHooks",
376
+ "LoggingHooks",
377
+ "MetricsHooks",
378
+ "CompositeHooks",
379
+ "WebhookHooks",
380
+ ]
flatagents/locking.py ADDED
@@ -0,0 +1,69 @@
1
+ import fcntl
2
+ import asyncio
3
+ import os
4
+ from abc import ABC, abstractmethod
5
+ from typing import Optional
6
+ from pathlib import Path
7
+ import contextlib
8
+
9
class ExecutionLock(ABC):
    """Abstract interface for concurrency control.

    Implementations provide a keyed lock with an async acquire/release pair.
    """

    @abstractmethod
    async def acquire(self, key: str) -> bool:
        """Try to take the lock for ``key``; return True on success."""
        return None

    @abstractmethod
    async def release(self, key: str) -> None:
        """Give up the lock for ``key``."""
        return None
21
+
22
class LocalFileLock(ExecutionLock):
    """
    File-based lock using fcntl.flock.
    Works on local filesystems and NFS (mostly).
    NOT suited for distributed cloud storage (S3/GCS).

    Args:
        lock_dir: Directory holding one ``<key>.lock`` file per key; created
            (including parents) if it does not exist.
    """

    def __init__(self, lock_dir: str = ".locks"):
        self.lock_dir = Path(lock_dir)
        self.lock_dir.mkdir(parents=True, exist_ok=True)
        # key -> open file object; the handle must stay open for as long as
        # the flock is held.
        self._files = {}

    async def acquire(self, key: str) -> bool:
        """Attempt a non-blocking exclusive lock on ``<lock_dir>/<key>.lock``.

        Returns:
            True when the lock was taken. False when ``flock`` fails (for
            example because the lock is already held) or when any other
            exception occurs, such as the lock file not being openable.
        """
        lock_path = self.lock_dir / f"{key}.lock"

        try:
            handle = open(lock_path, "a+")
            try:
                # Exclusive and non-blocking: fail immediately rather than
                # wait when the lock is already held.
                fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            except OSError:
                handle.close()
                return False
            self._files[key] = handle
            return True
        except Exception:
            return False

    async def release(self, key: str) -> None:
        """Unlock and close the handle held for ``key``; no-op if none is held."""
        handle = self._files.pop(key, None)
        if handle is None:
            return
        try:
            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
        finally:
            handle.close()
        # The lock file itself is left on disk; it is not unlinked here.
61
+
62
class NoOpLock(ExecutionLock):
    """Used when concurrency control is disabled or managed externally."""

    async def acquire(self, key: str) -> bool:
        """Always report success without locking anything."""
        return True

    async def release(self, key: str) -> None:
        """Do nothing; there is no lock to release."""
        return None