agyqueue 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agyqueue/storage.py ADDED
@@ -0,0 +1,423 @@
1
+ import sqlite3
2
+ import os
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from datetime import datetime, timezone
6
+ from typing import Optional, List
7
+ from contextlib import contextmanager
8
+ from agyqueue.models import Task, TaskStatus
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class BaseTaskStore(ABC):
    """Interface that every task storage backend has to implement."""

    @abstractmethod
    def save_task(self, task: Task) -> None:
        """Store ``task``: insert it when new, otherwise update the stored copy."""

    @abstractmethod
    def get_task(self, task_id: str) -> Optional[Task]:
        """Fetch the task identified by ``task_id``."""

    @abstractmethod
    def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        progress: int,
        step: str,
        result: Optional[str] = None,
        error: Optional[str] = None,
    ) -> Optional[Task]:
        """Set status, progress and current step of a task.

        ``result`` and ``error`` are optional extras that may be supplied
        together with the mandatory fields.
        """

    @abstractmethod
    def touch_task(self, task_id: str) -> None:
        """Refresh the task's ``updated_at`` heartbeat timestamp."""

    @abstractmethod
    def list_tasks(self) -> List[Task]:
        """Return every task, ordered by creation date descending."""

    @abstractmethod
    def get_subtasks(self, parent_id: str) -> List[Task]:
        """Return the child subtasks that belong to ``parent_id``."""
52
+
53
+
54
class SQLiteTaskStore(BaseTaskStore):
    """SQLite implementation of the TaskStore interface.

    Every public operation opens its own connection to ``self.db_path``,
    commits on success (rolls back on error) and then closes the connection.
    """

    def __init__(self, db_path: Optional[str] = None):
        """Resolve the database path and make sure the schema exists.

        Args:
            db_path: SQLite database file. When omitted, the
                ``AGYQUEUE_DB_PATH`` environment variable is used, falling
                back to ``agyqueue.db``.
        """
        self.db_path = db_path or os.environ.get("AGYQUEUE_DB_PATH", "agyqueue.db")
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new WAL-mode connection that yields ``sqlite3.Row`` rows.

        The caller owns the returned connection and must close it: using a
        ``sqlite3.Connection`` as a context manager only commits or rolls
        back, it never closes.
        """
        conn = sqlite3.connect(self.db_path, timeout=30.0)
        try:
            # Enable WAL mode for better concurrency in multi-process environments
            conn.execute("PRAGMA journal_mode=WAL;")
        except Exception:
            # Do not leak the handle when the pragma cannot be applied.
            conn.close()
            raise
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self):
        """Yield a connection that is committed/rolled back and always closed."""
        conn = self._get_conn()
        try:
            with conn:  # commit on success, rollback on exception
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string."""
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _row_to_task(row) -> Task:
        """Build a ``Task`` from one row of the ``tasks`` table."""
        return Task(
            task_id=row["task_id"],
            prompt=row["prompt"],
            task_type=row["task_type"],
            status=TaskStatus(row["status"]),
            progress=row["progress"],
            step=row["step"],
            result=row["result"],
            error=row["error"],
            parent_id=row["parent_id"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    def _init_db(self):
        """Create the ``tasks`` table and add ``parent_id`` to older schemas."""
        with self._connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    prompt TEXT NOT NULL,
                    task_type TEXT NOT NULL,
                    status TEXT NOT NULL,
                    progress INTEGER NOT NULL,
                    step TEXT NOT NULL,
                    result TEXT,
                    error TEXT,
                    parent_id TEXT,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL
                )
            """)
            conn.commit()

            # Schema migration: add parent_id if table already existed without it
            existing = {
                column["name"]
                for column in conn.execute("PRAGMA table_info(tasks)")
            }
            if "parent_id" in existing:
                return
            try:
                conn.execute("ALTER TABLE tasks ADD COLUMN parent_id TEXT;")
            except sqlite3.OperationalError as exc:
                # Another process may have added the column in the meantime;
                # anything else (locked/corrupt database) must surface.
                if "duplicate column" not in str(exc).lower():
                    raise

    def save_task(self, task: Task) -> None:
        """Insert ``task`` or replace the row that has the same ``task_id``."""
        values = (
            task.task_id,
            task.prompt,
            task.task_type,
            task.status.value,
            task.progress,
            task.step,
            task.result,
            task.error,
            task.parent_id,
            task.created_at,
            task.updated_at,
        )
        with self._connection() as conn:
            conn.execute(
                """
                INSERT OR REPLACE INTO tasks
                (task_id, prompt, task_type, status, progress, step, result, error, parent_id, created_at, updated_at)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                values,
            )

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or ``None`` when no row matches."""
        with self._connection() as conn:
            row = conn.execute(
                "SELECT * FROM tasks WHERE task_id = ?", (task_id,)
            ).fetchone()
        return None if row is None else self._row_to_task(row)

    def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        progress: int,
        step: str,
        result: Optional[str] = None,
        error: Optional[str] = None,
    ) -> Optional[Task]:
        """Update status, progress and step, and stamp ``updated_at``.

        ``result`` / ``error`` only overwrite the stored value when they are
        not ``None`` (``COALESCE`` keeps the previous value otherwise).

        Returns:
            The task as stored after the update, or ``None`` if it does not
            exist.
        """
        now = self._utc_now_iso()
        with self._connection() as conn:
            conn.execute(
                """
                UPDATE tasks
                SET status = ?, progress = ?, step = ?, result = COALESCE(?, result), error = COALESCE(?, error), updated_at = ?
                WHERE task_id = ?
                """,
                (status.value, progress, step, result, error, now, task_id),
            )
        return self.get_task(task_id)

    def touch_task(self, task_id: str) -> None:
        """Set ``updated_at`` of the task to the current UTC time."""
        now = self._utc_now_iso()
        with self._connection() as conn:
            conn.execute(
                "UPDATE tasks SET updated_at = ? WHERE task_id = ?",
                (now, task_id),
            )

    def list_tasks(self) -> List[Task]:
        """Return all tasks, newest ``created_at`` first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM tasks ORDER BY created_at DESC"
            ).fetchall()
        return [self._row_to_task(row) for row in rows]

    def get_subtasks(self, parent_id: str) -> List[Task]:
        """Return the tasks whose ``parent_id`` matches, oldest first."""
        with self._connection() as conn:
            rows = conn.execute(
                "SELECT * FROM tasks WHERE parent_id = ? ORDER BY created_at ASC",
                (parent_id,),
            ).fetchall()
        return [self._row_to_task(row) for row in rows]
212
+
213
+
214
class PostgreSQLTaskStore(BaseTaskStore):
    """PostgreSQL implementation of the TaskStore interface using connection pooling.

    The pool is stored on the class and created only once: the first instance
    builds it from its own ``connection_url`` and every later instance reuses
    that pool, whatever URL it was given.
    """

    _pool = None

    def __init__(self, connection_url: Optional[str] = None):
        """Set up the shared pool and make sure the ``tasks`` table exists.

        Args:
            connection_url: PostgreSQL DSN. When falsy,
                ``agyqueue.config.settings.database_url`` is used instead.
        """
        self.connection_url = connection_url
        if not self.connection_url:
            from agyqueue.config import settings
            self.connection_url = settings.database_url

        self._init_pool()
        self._init_db()

    def _init_pool(self):
        """Create the class-wide connection pool if it does not exist yet."""
        if PostgreSQLTaskStore._pool is None:
            import psycopg2.pool
            PostgreSQLTaskStore._pool = psycopg2.pool.ThreadedConnectionPool(
                minconn=1,
                maxconn=10,
                dsn=self.connection_url
            )
            logger.info("PostgreSQL thread pool connection manager initialized.")

    @contextmanager
    def _get_conn(self):
        """Borrow an autocommit connection from the pool and always return it."""
        conn = PostgreSQLTaskStore._pool.getconn()
        try:
            # Inside the try so the connection goes back to the pool even if
            # switching to autocommit fails.
            conn.autocommit = True
            yield conn
        finally:
            PostgreSQLTaskStore._pool.putconn(conn)

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as a naive ISO-8601 string."""
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def _row_to_task(row) -> Task:
        """Build a ``Task`` from one dict-like row of the ``tasks`` table."""
        return Task(
            task_id=row["task_id"],
            prompt=row["prompt"],
            task_type=row["task_type"],
            status=TaskStatus(row["status"]),
            progress=row["progress"],
            step=row["step"],
            result=row["result"],
            error=row["error"],
            parent_id=row["parent_id"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    def _execute(self, query: str, params=None) -> None:
        """Run a statement that returns no rows."""
        with self._get_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(query, params)

    def _fetch_tasks(self, query: str, params=None) -> List[Task]:
        """Run a SELECT on ``tasks`` and map every row to a ``Task``."""
        from psycopg2.extras import RealDictCursor
        with self._get_conn() as conn:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(query, params)
                rows = cur.fetchall()
        return [self._row_to_task(row) for row in rows]

    def _init_db(self):
        """Create the ``tasks`` table when it is missing."""
        self._execute("""
            CREATE TABLE IF NOT EXISTS tasks (
                task_id VARCHAR(50) PRIMARY KEY,
                prompt TEXT NOT NULL,
                task_type VARCHAR(50) NOT NULL,
                status VARCHAR(20) NOT NULL,
                progress INTEGER NOT NULL,
                step TEXT NOT NULL,
                result TEXT,
                error TEXT,
                parent_id VARCHAR(50),
                created_at VARCHAR(30) NOT NULL,
                updated_at VARCHAR(30) NOT NULL
            )
        """)

    def save_task(self, task: Task) -> None:
        """Insert ``task``; on a ``task_id`` conflict overwrite every column."""
        self._execute(
            """
            INSERT INTO tasks
            (task_id, prompt, task_type, status, progress, step, result, error, parent_id, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            ON CONFLICT (task_id) DO UPDATE SET
                prompt = EXCLUDED.prompt,
                task_type = EXCLUDED.task_type,
                status = EXCLUDED.status,
                progress = EXCLUDED.progress,
                step = EXCLUDED.step,
                result = EXCLUDED.result,
                error = EXCLUDED.error,
                parent_id = EXCLUDED.parent_id,
                created_at = EXCLUDED.created_at,
                updated_at = EXCLUDED.updated_at
            """,
            (
                task.task_id,
                task.prompt,
                task.task_type,
                task.status.value,
                task.progress,
                task.step,
                task.result,
                task.error,
                task.parent_id,
                task.created_at,
                task.updated_at,
            ),
        )

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or ``None`` when no row matches."""
        # task_id is the primary key, so at most one row comes back.
        matches = self._fetch_tasks(
            "SELECT * FROM tasks WHERE task_id = %s", (task_id,)
        )
        return matches[0] if matches else None

    def update_task(
        self,
        task_id: str,
        status: TaskStatus,
        progress: int,
        step: str,
        result: Optional[str] = None,
        error: Optional[str] = None,
    ) -> Optional[Task]:
        """Update status, progress and step, and stamp ``updated_at``.

        ``result`` / ``error`` only overwrite the stored value when they are
        not ``None`` (``COALESCE`` keeps the previous value otherwise).

        Returns:
            The task as stored after the update, or ``None`` if it does not
            exist.
        """
        now = self._utc_now_iso()
        self._execute(
            """
            UPDATE tasks
            SET status = %s, progress = %s, step = %s, result = COALESCE(%s, result), error = COALESCE(%s, error), updated_at = %s
            WHERE task_id = %s
            """,
            (status.value, progress, step, result, error, now, task_id),
        )
        return self.get_task(task_id)

    def touch_task(self, task_id: str) -> None:
        """Set ``updated_at`` of the task to the current UTC time."""
        self._execute(
            "UPDATE tasks SET updated_at = %s WHERE task_id = %s",
            (self._utc_now_iso(), task_id),
        )

    def list_tasks(self) -> List[Task]:
        """Return all tasks, newest ``created_at`` first."""
        return self._fetch_tasks("SELECT * FROM tasks ORDER BY created_at DESC")

    def get_subtasks(self, parent_id: str) -> List[Task]:
        """Return the tasks whose ``parent_id`` matches, oldest first."""
        return self._fetch_tasks(
            "SELECT * FROM tasks WHERE parent_id = %s ORDER BY created_at ASC",
            (parent_id,),
        )
401
+
402
+
403
def TaskStore(db_path: Optional[str] = None) -> BaseTaskStore:
    """Build the task store selected by the package settings.

    PostgreSQL is chosen when ``settings.store_type`` equals ``"postgres"`` or
    ``settings.database_url`` starts with ``"postgres"``. If constructing the
    PostgreSQL store raises, the problem is logged and a ``SQLiteTaskStore``
    for ``db_path`` is returned instead. In every other case SQLite is used.
    """
    from agyqueue.config import settings

    use_postgres = settings.store_type == "postgres"
    if not use_postgres:
        url = settings.database_url
        use_postgres = bool(url and url.startswith("postgres"))

    if use_postgres:
        try:
            return PostgreSQLTaskStore()
        except ImportError:
            logger.warning("psycopg2 not installed. Falling back to SQLiteTaskStore.")
        except Exception as exc:
            logger.error(f"Failed to connect to PostgreSQL: {exc}. Falling back to SQLiteTaskStore.")

    # Reached for the default backend and after any PostgreSQL failure above.
    return SQLiteTaskStore(db_path)
agyqueue/task_queue.py ADDED
@@ -0,0 +1,111 @@
1
+ import os
2
+ import time
3
+ import logging
4
+ from abc import ABC, abstractmethod
5
+ from typing import Optional
6
+ import redis
7
+ from agyqueue.storage import TaskStore
8
+ from agyqueue.models import TaskStatus
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class BaseTaskQueue(ABC):
    """Interface that every task queue backend has to implement."""

    @abstractmethod
    def enqueue(self, task_id: str) -> None:
        """Push ``task_id`` onto the queue."""

    @abstractmethod
    def dequeue(self, timeout: int = 1) -> Optional[str]:
        """Pop a task ID from the queue.

        Blocks for up to ``timeout`` seconds when the queue is empty.
        """
24
+
25
+
26
class RedisTaskQueue(BaseTaskQueue):
    """Redis-backed implementation of the TaskQueue interface.

    Task IDs are appended to the right of a Redis list and popped from its
    left, so they are handed out in the order they were enqueued.
    """

    # Redis list that holds the pending task IDs.
    _QUEUE_KEY = "agyqueue:task_ids"

    def __init__(self, redis_url: str):
        """Connect to Redis and verify the connection with a ``PING``.

        Args:
            redis_url: Connection URL passed to ``redis.Redis.from_url``.

        Raises:
            Whatever ``ping()`` raises when the server cannot be reached.
        """
        self.redis_url = redis_url
        self.redis_client = redis.Redis.from_url(
            self.redis_url,
            socket_connect_timeout=2.0,
            decode_responses=True
        )
        # Test connection
        self.redis_client.ping()
        # The URL may embed a password, so never log it verbatim.
        logger.info(f"Connected to Redis queue at {self._redact_url(self.redis_url)}")

    @staticmethod
    def _redact_url(url: str) -> str:
        """Return ``url`` with any ``user:password@`` part replaced by ``***@``."""
        from urllib.parse import urlsplit, urlunsplit

        try:
            parts = urlsplit(url)
        except ValueError:
            return "<redacted redis url>"
        _userinfo, at, hostport = parts.netloc.rpartition("@")
        if not at:
            return url
        return urlunsplit(parts._replace(netloc=f"***@{hostport}"))

    def enqueue(self, task_id: str) -> None:
        """Append ``task_id`` to the Redis list."""
        self.redis_client.rpush(self._QUEUE_KEY, task_id)
        logger.info(f"Enqueued task {task_id} to Redis.")

    def dequeue(self, timeout: int = 1) -> Optional[str]:
        """Pop the oldest task ID, waiting up to ``timeout`` seconds.

        Returns:
            The task ID, or ``None`` when nothing arrived in time.
        """
        # blpop returns (queue_name, value)
        res = self.redis_client.blpop(self._QUEUE_KEY, timeout=timeout)
        if res:
            return res[1]
        return None
50
+
51
+
52
class SQLiteTaskQueue(BaseTaskQueue):
    """SQLite-backed polling implementation of the TaskQueue interface.

    There is no separate queue table: a task counts as queued while its row in
    the ``tasks`` table has status ``QUEUED``, and ``dequeue`` claims the
    oldest such row by switching it to ``RUNNING``.
    """

    # Upper bound on SELECT/UPDATE rounds per dequeue when a claim loses a race.
    _MAX_CLAIM_ATTEMPTS = 3

    def __init__(self, db_path: Optional[str] = None):
        """Resolve the database path and build the backing task store.

        Args:
            db_path: SQLite database file. When omitted, the
                ``AGYQUEUE_DB_PATH`` environment variable is used, falling
                back to ``agyqueue.db``.
        """
        self.db_path = db_path or os.environ.get("AGYQUEUE_DB_PATH", "agyqueue.db")
        # NOTE(review): TaskStore() is a factory and may hand back a
        # non-SQLite store depending on settings; the SQL in dequeue() uses
        # SQLite placeholders — confirm that combination cannot occur.
        self.store = TaskStore(self.db_path)

    def enqueue(self, task_id: str) -> None:
        """Log the enqueue; the task row itself is what workers poll for."""
        # SQLite queue relies on the task status being set to QUEUED in the SQLite store
        logger.info(f"Enqueued task {task_id} to SQLite store (waiting for polling worker).")

    def dequeue(self, timeout: int = 1) -> Optional[str]:
        """Claim the oldest queued task, or sleep ``timeout`` seconds.

        Returns:
            The claimed task ID, or ``None`` when nothing could be claimed
            (no queued task, or the attempt failed and was logged).
        """
        try:
            task_id = self._claim_next_task()
        except Exception as e:
            logger.error(f"SQLite dequeue failed: {e}")
        else:
            if task_id is not None:
                return task_id

        time.sleep(timeout)
        return None

    def _claim_next_task(self) -> Optional[str]:
        """Atomically move the oldest ``QUEUED`` task to ``RUNNING``."""
        from datetime import datetime, timezone

        get_conn = getattr(self.store, "_get_conn", None)
        if get_conn is None:
            return None

        conn = get_conn()
        try:
            with conn:  # commits the claim on exit, rolls back on error
                for _ in range(self._MAX_CLAIM_ATTEMPTS):
                    row = conn.execute(
                        "SELECT task_id FROM tasks WHERE status = ? ORDER BY created_at ASC LIMIT 1",
                        (TaskStatus.QUEUED.value,)
                    ).fetchone()
                    if row is None:
                        return None

                    # Same naive-UTC ISO format the task store writes.
                    now = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
                    # Guard on status so a task another worker claimed between
                    # our SELECT and UPDATE is not claimed a second time.
                    claimed = conn.execute(
                        "UPDATE tasks SET status = ?, step = ?, updated_at = ? "
                        "WHERE task_id = ? AND status = ?",
                        (
                            TaskStatus.RUNNING.value,
                            "Claimed by worker",
                            now,
                            row["task_id"],
                            TaskStatus.QUEUED.value,
                        )
                    )
                    if claimed.rowcount == 1:
                        return row["task_id"]
            return None
        finally:
            # `with conn` does not close a sqlite3 connection.
            close = getattr(conn, "close", None)
            if close is not None:
                close()
95
+
96
+
97
def _redact_redis_url(url: str) -> str:
    """Return ``url`` with any ``user:password@`` part replaced by ``***@``."""
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(url)
    except ValueError:
        return "<redacted redis url>"
    _userinfo, at, hostport = parts.netloc.rpartition("@")
    if not at:
        return url
    return urlunsplit(parts._replace(netloc=f"***@{hostport}"))


def TaskQueue(redis_url: Optional[str] = None, db_path: Optional[str] = None) -> BaseTaskQueue:
    """Factory function returning the configured TaskQueue implementation.

    Tries to connect to Redis if a URL is provided or set in the ``REDIS_URL``
    environment variable, falling back to SQLite if Redis is unavailable.

    Args:
        redis_url: Redis connection URL; takes precedence over ``REDIS_URL``.
        db_path: Database path handed to ``SQLiteTaskQueue`` on fallback.
    """
    url = redis_url or os.environ.get("REDIS_URL")
    if url:
        try:
            return RedisTaskQueue(url)
        except Exception as e:
            # The URL may embed a password, so only a redacted form is logged.
            logger.warning(
                f"Could not connect to Redis at {_redact_redis_url(url)}: {e}. "
                "Falling back to SQLite queue."
            )

    logger.info("Using SQLite database as task queue.")
    return SQLiteTaskQueue(db_path)