fairchild 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fairchild/worker.py ADDED
@@ -0,0 +1,495 @@
1
import asyncio
import json
import os
import socket
import traceback
from datetime import datetime, timezone
from typing import Any
from uuid import UUID, uuid4

from fairchild.context import set_current_job, get_pending_children
from fairchild.fairchild import Fairchild
from fairchild.job import Job, JobState
from fairchild.record import Record
from fairchild.task import get_task
15
+
16
+
17
class Worker:
    """A worker that processes jobs from a specific queue.

    Each worker polls its queue, atomically claims one runnable job using
    ``FOR UPDATE SKIP LOCKED`` (so concurrent workers never double-claim),
    executes the task, and records the outcome: completed, waiting on
    children, retried, or discarded.
    """

    def __init__(
        self,
        fairchild: Fairchild,
        queue: str,
        worker_id: int,
        pool: "WorkerPool | None" = None,
    ):
        self.fairchild = fairchild
        self.queue = queue
        self.worker_id = worker_id
        # Optional owning pool; consulted each loop iteration for pause state.
        self.pool = pool
        self._running = False
        # Job currently being executed, if any; read by the pool's heartbeat
        # to report active job counts.
        self._current_job: Job | None = None

    @property
    def name(self) -> str:
        """Human-readable identifier used in log output (queue:worker_id)."""
        return f"{self.queue}:{self.worker_id}"

    async def run(self):
        """Main worker loop: poll, claim, execute, repeat until stopped.

        Unexpected exceptions (e.g. transient DB errors) are logged and the
        loop continues after a short pause; only cancellation or ``stop()``
        ends the loop.
        """
        self._running = True
        print(f"[{self.name}] Started")

        while self._running:
            try:
                # Respect a pool-level pause without exiting the loop.
                if self.pool and self.pool.is_paused:
                    await asyncio.sleep(1.0)
                    continue

                job = await self._fetch_job()

                if job is None:
                    # No job available; back off before polling again.
                    await asyncio.sleep(1.0)
                    continue

                self._current_job = job
                try:
                    await self._execute_job(job)
                finally:
                    # Clear even if execution raised, so the heartbeat does
                    # not report a stale active job.
                    self._current_job = None

            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"[{self.name}] Error in worker loop: {e}")
                await asyncio.sleep(1.0)

        print(f"[{self.name}] Stopped")

    async def _fetch_job(self) -> Job | None:
        """Atomically claim the next runnable job from this worker's queue.

        A job is runnable when its state is 'available', it is due
        (``scheduled_at <= now()``), and every job id in its ``deps`` array
        has completed. ``FOR UPDATE SKIP LOCKED`` makes concurrent workers
        skip rows another transaction has already claimed.

        Returns:
            The claimed Job, already moved to 'running' with its attempt
            counter incremented, or None when no job is available.
        """
        select_query = """
        SELECT id FROM fairchild_jobs
        WHERE queue = $1
          AND state = 'available'
          AND scheduled_at <= now()
          AND (
            -- Either no deps or all deps completed
            deps = '{}'
            OR NOT EXISTS (
                SELECT 1 FROM fairchild_jobs dep
                WHERE dep.id::text = ANY(fairchild_jobs.deps)
                  AND dep.state != 'completed'
            )
          )
        ORDER BY priority, scheduled_at
        LIMIT 1
        FOR UPDATE SKIP LOCKED
        """

        update_query = """
        UPDATE fairchild_jobs
        SET state = 'running',
            attempted_at = now(),
            attempt = attempt + 1,
            updated_at = now()
        WHERE id = $1
        RETURNING *
        """

        # Both statements run in one transaction so the row lock taken by
        # the SELECT is held until the UPDATE commits.
        async with self.fairchild._pool.acquire() as conn:
            async with conn.transaction():
                row = await conn.fetchrow(select_query, self.queue)
                if row is None:
                    return None

                row = await conn.fetchrow(update_query, row["id"])

        return Job.from_row(dict(row))

    async def _resolve_future_args(self, args: dict) -> dict:
        """Resolve ``__future__`` markers in args to recorded job results.

        A marker is a dict of exactly the form ``{"__future__": "<job-id>"}``;
        it is replaced by the JSON-decoded ``recorded`` value of that
        completed job, or None when the job has not completed or recorded
        nothing. Dicts and lists are traversed recursively; all other values
        pass through unchanged.
        """

        async def _resolve(obj):
            if isinstance(obj, dict):
                # A single-key {"__future__": id} dict is a marker, not data.
                if "__future__" in obj and len(obj) == 1:
                    job_id = obj["__future__"]
                    query = """
                    SELECT recorded FROM fairchild_jobs
                    WHERE id = $1 AND state = 'completed'
                    """
                    result = await self.fairchild._pool.fetchval(query, UUID(job_id))
                    if result is not None:
                        return json.loads(result)
                    return None
                # Regular dict - recurse into values.
                return {k: await _resolve(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [await _resolve(item) for item in obj]
            return obj

        return await _resolve(args)

    async def _execute_job(self, job: Job):
        """Execute a job: run its task, then complete, wait, or retry it.

        Outcomes:
        - Success, no spawned children: job marked completed.
        - Success with children: recorded value stored; job stays 'running'
          until all children finish (see ``_check_parent_completion``).
        - Unknown task name: discarded immediately (it can never succeed).
        - Error: retried (state back to 'available') or discarded once
          ``max_attempts`` is reached; error info is appended to the job.
        """
        print(f"[{self.name}] Processing job {job.id} ({job.task_name})")

        try:
            task = get_task(job.task_name)
        except ValueError as e:
            await self._fail_job(job, str(e), discard=True)
            return

        try:
            # Build arguments, resolving any futures from dependency jobs.
            kwargs = await self._resolve_future_args(dict(job.args))

            # Inject the job itself when the task's signature asks for it.
            if task._accepts_job:
                kwargs["job"] = job

            # Set context so tasks spawned during execution know they're
            # inside a worker (spawns are collected, not inserted directly).
            set_current_job(job, self.fairchild)

            try:
                result = task.fn(**kwargs)

                # Support both sync and async task functions.
                if asyncio.iscoroutine(result):
                    result = await result

                # Child jobs queued by the task during execution.
                pending_children = get_pending_children()
            finally:
                # Always clear context, even when the task raised.
                set_current_job(None)

            # Insert any child jobs that were spawned.
            for child_job in pending_children:
                await self.fairchild._insert_job(child_job)

            # Tasks wrap their return in Record() to persist the value.
            recorded_value = result.value if isinstance(result, Record) else None

            if pending_children:
                # Parent completes only after all of its children do.
                await self._mark_waiting_for_children(job, recorded_value)
                print(f"[{self.name}] Job {job.id} waiting for children to complete")
            else:
                await self._complete_job(job, recorded_value)
                print(f"[{self.name}] Completed job {job.id}")

        except Exception as e:
            error_info = {
                "attempt": job.attempt,
                "error": str(e),
                "traceback": traceback.format_exc(),
                # Timezone-aware UTC; datetime.utcnow() is deprecated.
                "at": datetime.now(timezone.utc).isoformat(),
            }

            if job.attempt >= job.max_attempts:
                await self._fail_job(job, str(e), error_info=error_info, discard=True)
                print(
                    f"[{self.name}] Discarded job {job.id} after {job.attempt} attempts: {e}"
                )
            else:
                await self._fail_job(job, str(e), error_info=error_info, discard=False)
                print(
                    f"[{self.name}] Job {job.id} failed (attempt {job.attempt}/{job.max_attempts}): {e}"
                )

    async def _complete_job(self, job: Job, recorded: Any | None):
        """Mark a job as completed and propagate completion effects.

        After the update this also (a) flips any 'scheduled' jobs whose deps
        are now all satisfied to 'available', and (b) checks whether a parent
        job can now be completed.
        """
        query = """
        UPDATE fairchild_jobs
        SET state = 'completed',
            completed_at = now(),
            recorded = $2,
            updated_at = now()
        WHERE id = $1
        """

        recorded_json = json.dumps(recorded) if recorded is not None else None
        await self.fairchild._pool.execute(query, job.id, recorded_json)

        # Check if this completion unblocks jobs waiting on this one.
        await self._check_child_deps(job)

        # Check if this completion allows a parent job to complete.
        if job.parent_id:
            await self._check_parent_completion(job.parent_id)

    async def _has_pending_children(self, job: Job) -> bool:
        """Return True if this job has any non-completed children."""
        query = """
        SELECT EXISTS(
            SELECT 1 FROM fairchild_jobs
            WHERE parent_id = $1
              AND state != 'completed'
        )
        """
        return await self.fairchild._pool.fetchval(query, job.id)

    async def _mark_waiting_for_children(self, job: Job, recorded: Any | None):
        """Store a parent job's result while it waits for its children.

        The job intentionally stays in the 'running' state; it is completed
        later by ``_check_parent_completion`` when the last child finishes.
        """
        query = """
        UPDATE fairchild_jobs
        SET recorded = $2,
            updated_at = now()
        WHERE id = $1
        """
        recorded_json = json.dumps(recorded) if recorded is not None else None
        await self.fairchild._pool.execute(query, job.id, recorded_json)

    async def _check_parent_completion(self, parent_id):
        """Complete a parent job once all of its children are done.

        Recurses upward: if the completed parent itself has a parent, the
        same check is applied one level up.
        """
        query = """
        SELECT NOT EXISTS(
            SELECT 1 FROM fairchild_jobs
            WHERE parent_id = $1
              AND state != 'completed'
        )
        """
        all_children_done = await self.fairchild._pool.fetchval(query, parent_id)

        if all_children_done:
            # Only a parent still in 'running' (i.e. waiting on children)
            # should be flipped to completed.
            complete_query = """
            UPDATE fairchild_jobs
            SET state = 'completed',
                completed_at = now(),
                updated_at = now()
            WHERE id = $1
              AND state = 'running'
            """
            await self.fairchild._pool.execute(complete_query, parent_id)

            # Walk up the tree in case the grandparent is now unblocked too.
            parent_parent_query = """
            SELECT parent_id FROM fairchild_jobs WHERE id = $1
            """
            parent_parent_id = await self.fairchild._pool.fetchval(
                parent_parent_query, parent_id
            )
            if parent_parent_id:
                await self._check_parent_completion(parent_parent_id)

    async def _check_child_deps(self, completed_job: Job):
        """Flip 'scheduled' jobs to 'available' when their deps are all done.

        Finds jobs whose ``deps`` array contains the completed job's id and
        makes them available if no remaining dep is incomplete.
        """
        job_id_str = str(completed_job.id)
        query = """
        UPDATE fairchild_jobs
        SET state = 'available', updated_at = now()
        WHERE state = 'scheduled'
          AND $1 = ANY(deps)
          AND NOT EXISTS (
            SELECT 1 FROM fairchild_jobs dep
            WHERE dep.id::text = ANY(fairchild_jobs.deps)
              AND dep.state != 'completed'
          )
        """
        await self.fairchild._pool.execute(query, job_id_str)

    async def _fail_job(
        self,
        job: Job,
        error: str,
        error_info: dict | None = None,
        discard: bool = False,
    ):
        """Mark a job as failed.

        Args:
            job: The job that failed.
            error: Short error description (currently informational only;
                persisted detail comes from ``error_info``).
            error_info: Structured error record appended to the job's
                ``errors`` array when provided.
            discard: True moves the job to 'discarded' (terminal); False
                returns it to 'available' for another attempt.
        """
        new_state = "discarded" if discard else "available"

        if error_info:
            # Append the structured error to the job's errors array.
            query = """
            UPDATE fairchild_jobs
            SET state = $2,
                errors = errors || $3::jsonb,
                updated_at = now()
            WHERE id = $1
            """
            await self.fairchild._pool.execute(
                query,
                job.id,
                new_state,
                json.dumps(error_info),
            )
        else:
            query = """
            UPDATE fairchild_jobs
            SET state = $2, updated_at = now()
            WHERE id = $1
            """
            await self.fairchild._pool.execute(query, job.id, new_state)

    def stop(self):
        """Signal the worker loop to exit after its current iteration."""
        self._running = False

351
class WorkerPool:
    """Manages multiple workers across queues.

    The pool registers itself in ``fairchild_workers``, spawns the
    configured number of Worker tasks per queue, and sends periodic
    heartbeats that both report liveness/active-job counts and pick up
    pause/resume requests from the database.
    """

    def __init__(self, fairchild: Fairchild, queue_config: dict[str, int]):
        """
        Args:
            fairchild: Fairchild instance
            queue_config: Dict mapping queue name to worker count
        """
        self.fairchild = fairchild
        self.queue_config = queue_config
        self.workers: list[Worker] = []
        self._tasks: list[asyncio.Task] = []
        self._heartbeat_task: asyncio.Task | None = None

        # Worker pool identity, persisted in fairchild_workers.
        self.id = uuid4()
        self.hostname = socket.gethostname()
        self.pid = os.getpid()
        self._paused = False

    async def run(self):
        """Register the pool, start all workers, and wait until cancelled."""
        # Register this worker pool before any worker starts pulling jobs.
        await self._register()

        # Create the configured number of workers per queue.
        for queue, count in self.queue_config.items():
            for i in range(count):
                worker = Worker(self.fairchild, queue, i, pool=self)
                self.workers.append(worker)

        # Start worker tasks.
        self._tasks = [asyncio.create_task(worker.run()) for worker in self.workers]

        # Start heartbeat task.
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())

        total = sum(self.queue_config.values())
        print(f"Started {total} workers across {len(self.queue_config)} queues")

        # Wait for all tasks (they run until cancelled).
        try:
            await asyncio.gather(*self._tasks)
        except asyncio.CancelledError:
            pass

    async def _register(self):
        """Register this worker pool in the database (upsert by pool id)."""
        query = """
        INSERT INTO fairchild_workers (id, hostname, pid, queues, state)
        VALUES ($1, $2, $3, $4::jsonb, 'running')
        ON CONFLICT (id) DO UPDATE SET
            hostname = EXCLUDED.hostname,
            pid = EXCLUDED.pid,
            queues = EXCLUDED.queues,
            state = 'running',
            last_heartbeat_at = now()
        """
        # asyncpg requires JSON encoded as a string for jsonb parameters.
        await self.fairchild._pool.execute(
            query,
            self.id,
            self.hostname,
            self.pid,
            json.dumps(self.queue_config),
        )
        print(f"Registered worker pool {self.id} ({self.hostname}:{self.pid})")

    async def _heartbeat_loop(self):
        """Send a heartbeat every 5 seconds until cancelled.

        Individual heartbeat failures are logged and swallowed so a
        transient DB outage does not kill the loop.
        """
        while True:
            try:
                await asyncio.sleep(5)  # Heartbeat every 5 seconds
                await self._send_heartbeat()
            except asyncio.CancelledError:
                break
            except Exception as e:
                print(f"Heartbeat error: {e}")

    async def _send_heartbeat(self):
        """Report liveness and active jobs; sync pause state from the DB."""
        active_jobs = sum(1 for w in self.workers if w._current_job is not None)

        query = """
        UPDATE fairchild_workers
        SET last_heartbeat_at = now(),
            active_jobs = $2
        WHERE id = $1
        RETURNING state
        """
        state = await self.fairchild._pool.fetchval(query, self.id, active_jobs)

        # The DB row is the source of truth for pause/resume; log on change.
        new_paused = state == "paused"
        if new_paused != self._paused:
            self._paused = new_paused
            if self._paused:
                print(f"Worker pool {self.id} paused")
            else:
                print(f"Worker pool {self.id} resumed")

    @property
    def is_paused(self) -> bool:
        """True when the pool's DB row requested a pause; read by workers."""
        return self._paused

    async def _unregister(self):
        """Mark this worker pool as stopped in the database."""
        query = """
        UPDATE fairchild_workers
        SET state = 'stopped', active_jobs = 0
        WHERE id = $1
        """
        await self.fairchild._pool.execute(query, self.id)

    async def shutdown(self):
        """Gracefully shut down: stop heartbeat, workers, then unregister."""
        print("Shutting down workers...")

        # Cancel the heartbeat first so we stop advertising liveness.
        if self._heartbeat_task:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        # Signal all workers to stop their loops...
        for worker in self.workers:
            worker.stop()

        # ...and cancel the tasks so sleeps/polls end immediately.
        for task in self._tasks:
            task.cancel()

        # Wait for tasks to finish; exceptions are collected, not raised.
        if self._tasks:
            await asyncio.gather(*self._tasks, return_exceptions=True)

        # Unregister from the database last.
        await self._unregister()

        print("All workers stopped")