flatmachines-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. flatmachines/__init__.py +136 -0
  2. flatmachines/actions.py +408 -0
  3. flatmachines/adapters/__init__.py +38 -0
  4. flatmachines/adapters/flatagent.py +86 -0
  5. flatmachines/adapters/pi_agent_bridge.py +127 -0
  6. flatmachines/adapters/pi_agent_runner.mjs +99 -0
  7. flatmachines/adapters/smolagents.py +125 -0
  8. flatmachines/agents.py +144 -0
  9. flatmachines/assets/MACHINES.md +141 -0
  10. flatmachines/assets/README.md +11 -0
  11. flatmachines/assets/__init__.py +0 -0
  12. flatmachines/assets/flatagent.d.ts +219 -0
  13. flatmachines/assets/flatagent.schema.json +271 -0
  14. flatmachines/assets/flatagent.slim.d.ts +58 -0
  15. flatmachines/assets/flatagents-runtime.d.ts +523 -0
  16. flatmachines/assets/flatagents-runtime.schema.json +281 -0
  17. flatmachines/assets/flatagents-runtime.slim.d.ts +187 -0
  18. flatmachines/assets/flatmachine.d.ts +403 -0
  19. flatmachines/assets/flatmachine.schema.json +620 -0
  20. flatmachines/assets/flatmachine.slim.d.ts +106 -0
  21. flatmachines/assets/profiles.d.ts +140 -0
  22. flatmachines/assets/profiles.schema.json +93 -0
  23. flatmachines/assets/profiles.slim.d.ts +26 -0
  24. flatmachines/backends.py +222 -0
  25. flatmachines/distributed.py +835 -0
  26. flatmachines/distributed_hooks.py +351 -0
  27. flatmachines/execution.py +638 -0
  28. flatmachines/expressions/__init__.py +60 -0
  29. flatmachines/expressions/cel.py +101 -0
  30. flatmachines/expressions/simple.py +166 -0
  31. flatmachines/flatmachine.py +1263 -0
  32. flatmachines/hooks.py +381 -0
  33. flatmachines/locking.py +69 -0
  34. flatmachines/monitoring.py +505 -0
  35. flatmachines/persistence.py +213 -0
  36. flatmachines/run.py +117 -0
  37. flatmachines/utils.py +166 -0
  38. flatmachines/validation.py +79 -0
  39. flatmachines-1.0.0.dist-info/METADATA +390 -0
  40. flatmachines-1.0.0.dist-info/RECORD +41 -0
  41. flatmachines-1.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,351 @@
+"""
+Base hooks for distributed worker patterns.
+
+Provides ready-to-use action handlers for:
+- Pool management (get_pool_state, claim_job, complete_job, fail_job)
+- Worker lifecycle (register_worker, deregister_worker, heartbeat)
+- Reaper (list_stale_workers, reap_worker, reap_stale_workers)
+- Auto-scaling (calculate_spawn, spawn_workers)
+
+Users extend this class and add custom job-processing actions.
+
+Example:
+    from flatmachines import DistributedWorkerHooks, SQLiteRegistrationBackend, SQLiteWorkBackend
+
+    class MyHooks(DistributedWorkerHooks):
+        def __init__(self):
+            super().__init__(
+                registration=SQLiteRegistrationBackend(db_path="./data/workers.db"),
+                work=SQLiteWorkBackend(db_path="./data/workers.db"),
+            )
+
+        async def _process_my_job(self, context):
+            # Custom job processing
+            return context
+"""
+
+import logging
+from typing import Dict, Any, Optional, Protocol, List, TYPE_CHECKING
+
+from .hooks import MachineHooks
+from .distributed import (
+    RegistrationBackend,
+    WorkBackend,
+    WorkerRegistration,
+    WorkerFilter,
+)
+
+if TYPE_CHECKING:
+    pass
+
+logger = logging.getLogger(__name__)
+
+
+class DistributedWorkerHooks(MachineHooks):
+    """
+    Ready-to-use hooks for distributed worker patterns.
+
+    Provides standard actions for work distribution, worker lifecycle,
+    stale worker cleanup, and auto-scaling.
+
+    Extend this class and add your own job-processing actions.
+    """
+
+    def __init__(
+        self,
+        registration: RegistrationBackend,
+        work: WorkBackend,
+    ):
+        """
+        Initialize with backend instances.
+
+        Args:
+            registration: Backend for worker registration/lifecycle
+            work: Backend for work pool operations
+        """
+        self._registration = registration
+        self._work = work
+
+    async def on_action(self, action: str, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Route action names to handler methods."""
+
+        # Standard distributed worker actions
+        handlers = {
+            # Pool state (for checkers)
+            "get_pool_state": self._get_pool_state,
+
+            # Job operations (for workers)
+            "claim_job": self._claim_job,
+            "complete_job": self._complete_job,
+            "fail_job": self._fail_job,
+
+            # Worker lifecycle
+            "register_worker": self._register_worker,
+            "deregister_worker": self._deregister_worker,
+            "heartbeat": self._heartbeat,
+
+            # Reaper operations
+            "list_stale_workers": self._list_stale_workers,
+            "reap_worker": self._reap_worker,
+            "reap_stale_workers": self._reap_stale_workers,
+
+            # Auto-scaling
+            "calculate_spawn": self._calculate_spawn,
+            "spawn_workers": self._spawn_workers,
+        }
+
+        handler = handlers.get(action)
+        if handler:
+            return await handler(context)
+
+        # Fall through to subclass or default behavior
+        return context
+
+    # -------------------------------------------------------------------------
+    # Pool State Actions (for parallelization checker)
+    # -------------------------------------------------------------------------
+
+    async def _get_pool_state(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Get current pool depth and active worker count."""
+        pool_id = context.get("pool_id", "default")
+
+        pool = self._work.pool(pool_id)
+        workers = await self._registration.list(WorkerFilter(status="active"))
+
+        # Merge output into context (preserve existing context keys)
+        context["queue_depth"] = await pool.size()
+        context["active_workers"] = len(workers)
+        return context
+
+    # -------------------------------------------------------------------------
+    # Job Actions (for job workers)
+    # -------------------------------------------------------------------------
+
+    async def _claim_job(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Atomically claim the next available job."""
+        pool_id = context.get("pool_id", "default")
+        worker_id = context.get("worker_id")
+
+        if not worker_id:
+            raise ValueError("worker_id is required for claim_job")
+
+        pool = self._work.pool(pool_id)
+        item = await pool.claim(worker_id)
+
+        if item:
+            context["job"] = item.data
+            context["job_id"] = item.id
+        else:
+            context["job"] = None
+            context["job_id"] = None
+        return context
+
+    async def _complete_job(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Mark job as successfully completed."""
+        pool_id = context.get("pool_id", "default")
+        job_id = context.get("job_id")
+        result = context.get("result")
+
+        if not job_id:
+            raise ValueError("job_id is required for complete_job")
+
+        pool = self._work.pool(pool_id)
+        await pool.complete(job_id, result)
+
+        return context
+
+    async def _fail_job(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Mark job as failed. Will retry or poison based on attempts."""
+        pool_id = context.get("pool_id", "default")
+        job_id = context.get("job_id")
+        error = context.get("error")
+
+        if not job_id:
+            raise ValueError("job_id is required for fail_job")
+
+        pool = self._work.pool(pool_id)
+        await pool.fail(job_id, error)
+
+        return context
+
+    # -------------------------------------------------------------------------
+    # Worker Lifecycle Actions
+    # -------------------------------------------------------------------------
+
+    async def _register_worker(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Register a new worker."""
+        import socket
+        import os
+
+        worker_id = context.get("worker_id")
+        if not worker_id:
+            raise ValueError("worker_id is required for register_worker")
+
+        registration = WorkerRegistration(
+            worker_id=worker_id,
+            host=socket.gethostname(),
+            pid=os.getpid(),
+            capabilities=context.get("capabilities", []),
+        )
+
+        record = await self._registration.register(registration)
+
+        # Merge output into context
+        context["worker_id"] = record.worker_id
+        context["status"] = record.status
+        context["registered_at"] = record.started_at
+        return context
+
+    async def _deregister_worker(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Mark worker as terminated (clean shutdown)."""
+        worker_id = context.get("worker_id")
+        if not worker_id:
+            raise ValueError("worker_id is required for deregister_worker")
+
+        await self._registration.update_status(worker_id, "terminated")
+
+        return context
+
+    async def _heartbeat(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Send heartbeat for a worker."""
+        worker_id = context.get("worker_id")
+        if not worker_id:
+            raise ValueError("worker_id is required for heartbeat")
+
+        await self._registration.heartbeat(worker_id)
+
+        return context
+
+    # -------------------------------------------------------------------------
+    # Reaper Actions
+    # -------------------------------------------------------------------------
+
+    async def _list_stale_workers(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Find workers that have missed heartbeat threshold."""
+        threshold = context.get("stale_threshold_seconds", 60)
+
+        workers = await self._registration.list(WorkerFilter(
+            status="active",
+            stale_threshold_seconds=threshold,
+        ))
+
+        stale_workers = [
+            {
+                "worker_id": w.worker_id,
+                "last_heartbeat": w.last_heartbeat,
+                "host": w.host,
+            }
+            for w in workers
+        ]
+
+        # Merge into context with count for condition evaluation
+        context["workers"] = stale_workers
+        context["stale_count"] = len(stale_workers)
+        return context
+
+    async def _reap_worker(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Mark worker as lost and release their claimed jobs."""
+        worker = context.get("worker")
+        pool_id = context.get("pool_id", "default")
+
+        if not worker:
+            raise ValueError("worker is required for reap_worker")
+
+        worker_id = worker.get("worker_id")
+
+        # Mark as lost
+        await self._registration.update_status(worker_id, "lost")
+
+        # Release any jobs claimed by this worker
+        pool = self._work.pool(pool_id)
+        released = await pool.release_by_worker(worker_id)
+
+        context["reaped_worker_id"] = worker_id
+        context["jobs_released"] = released
+        return context
+
+    async def _reap_stale_workers(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Reap all stale workers in one action (batch processing)."""
+        stale_workers = context.get("stale_workers", [])
+        pool_id = context.get("pool_id", "default")
+
+        reaped = []
+        total_jobs_released = 0
+
+        for worker in stale_workers:
+            worker_id = worker.get("worker_id")
+
+            # Mark as lost
+            await self._registration.update_status(worker_id, "lost")
+
+            # Release any jobs claimed by this worker
+            pool = self._work.pool(pool_id)
+            released = await pool.release_by_worker(worker_id)
+
+            reaped.append(worker_id)
+            total_jobs_released += released
+
+        context["reaped_workers"] = reaped
+        context["reaped_count"] = len(reaped)
+        context["total_jobs_released"] = total_jobs_released
+        return context
+
+    # -------------------------------------------------------------------------
+    # Auto-Scaling Actions (for parallelization checker)
+    # -------------------------------------------------------------------------
+
+    async def _calculate_spawn(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Calculate how many workers to spawn based on pool state.
+
+        Default strategy: workers_needed = min(queue_depth, max_workers)
+        Override this method for custom scaling logic.
+        """
+        queue_depth = int(context.get("queue_depth", 0))
+        active_workers = int(context.get("active_workers", 0))
+        max_workers = int(context.get("max_workers", 3))
+
+        # Workers needed = min(queue_depth, max_workers)
+        workers_needed = min(queue_depth, max_workers)
+
+        # Workers to spawn = max(0, workers_needed - active_workers)
+        workers_to_spawn = max(0, workers_needed - active_workers)
+
+        context["workers_needed"] = workers_needed
+        context["workers_to_spawn"] = workers_to_spawn
+        context["spawn_list"] = list(range(workers_to_spawn))  # For foreach iteration
+        return context
+
+    async def _spawn_workers(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Spawn worker subprocesses based on workers_to_spawn count.
+
+        Requires `worker_config_path` in context - the path to the worker YAML.
+        """
+        import uuid
+        from .actions import launch_machine
+
+        workers_to_spawn = int(context.get("workers_to_spawn", 0))
+        pool_id = context.get("pool_id", "default")
+        worker_config_path = context.get("worker_config_path")
+
+        if not worker_config_path and workers_to_spawn > 0:
+            raise ValueError("worker_config_path required in context for spawn_workers")
+
+        spawned_ids = []
+        for i in range(workers_to_spawn):
+            worker_id = f"worker-{uuid.uuid4().hex[:8]}"
+
+            # Launch worker in subprocess
+            launch_machine(
+                machine_config=worker_config_path,
+                input_data={
+                    "pool_id": pool_id,
+                    "worker_id": worker_id,
+                },
+            )
+
+            spawned_ids.append(worker_id)
+            logger.info(f"Spawned worker subprocess: {worker_id}")
+
+        context["spawned_ids"] = spawned_ids
+        context["spawned_count"] = len(spawned_ids)
+        return context
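
To make the default auto-scaling strategy in `_calculate_spawn` concrete: with a queue depth of 10, 2 active workers, and `max_workers=3`, workers_needed = min(10, 3) = 3 and workers_to_spawn = max(0, 3 - 2) = 1, so one new worker is launched and `spawn_list` becomes [0].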
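
Since `on_action` only routes the standard action names above and otherwise returns the context unchanged, a subclass that adds its own job-processing action can override `on_action` and delegate unrecognized actions back to the base class. A minimal sketch, reusing the SQLite backends from the module docstring; the `OrderHooks` class and the `process_order` action name are hypothetical:

from flatmachines import DistributedWorkerHooks, SQLiteRegistrationBackend, SQLiteWorkBackend

class OrderHooks(DistributedWorkerHooks):
    def __init__(self):
        super().__init__(
            registration=SQLiteRegistrationBackend(db_path="./data/workers.db"),
            work=SQLiteWorkBackend(db_path="./data/workers.db"),
        )

    async def on_action(self, action, context):
        # Hypothetical custom action; all standard actions (claim_job,
        # heartbeat, reap_stale_workers, ...) fall through to the base class.
        if action == "process_order":
            return await self._process_order(context)
        return await super().on_action(action, context)

    async def _process_order(self, context):
        # claim_job placed the claimed item under context["job"];
        # complete_job later reads context["result"].
        order = context.get("job") or {}
        context["result"] = {"order_id": order.get("id"), "status": "processed"}
        return context

Writing `context["result"]` is what lets the stock `complete_job` handler persist the outcome, since it forwards that value to `pool.complete(job_id, result)`.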