modelq 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modelq-0.1.0/LICENSE +21 -0
- modelq-0.1.0/PKG-INFO +17 -0
- modelq-0.1.0/README.md +1 -0
- modelq-0.1.0/modelq/__init__.py +6 -0
- modelq-0.1.0/modelq/app/__init__.py +5 -0
- modelq-0.1.0/modelq/app/base.py +263 -0
- modelq-0.1.0/modelq/app/cache/__init__.py +5 -0
- modelq-0.1.0/modelq/app/cache/base.py +76 -0
- modelq-0.1.0/modelq/app/middleware/__init__.py +5 -0
- modelq-0.1.0/modelq/app/middleware/base.py +18 -0
- modelq-0.1.0/modelq/app/tasks/__init__.py +5 -0
- modelq-0.1.0/modelq/app/tasks/base.py +127 -0
- modelq-0.1.0/modelq/app/utils/__init__.py +5 -0
- modelq-0.1.0/modelq/app/utils/base64.py +14 -0
- modelq-0.1.0/modelq/exceptions.py +13 -0
- modelq-0.1.0/pyproject.toml +18 -0
modelq-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 ModelsLab
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
modelq-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: modelq
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Celery-like task queue for ML inference.
|
|
5
|
+
Author: Tanmaypatil123
|
|
6
|
+
Author-email: tanmay@modelslab.com
|
|
7
|
+
Requires-Python: >=3.9,<4.0
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Requires-Dist: click (>=8.0.0,<9.0.0)
|
|
14
|
+
Requires-Dist: redis (>=4.0.0,<5.0.0)
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# celery-plus-plus
|
modelq-0.1.0/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# celery-plus-plus
|
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
from typing import Optional, Dict, Any, Generator
|
|
2
|
+
import redis
|
|
3
|
+
import json
|
|
4
|
+
import functools
|
|
5
|
+
import threading
|
|
6
|
+
import time
|
|
7
|
+
import sqlite3
|
|
8
|
+
import uuid
|
|
9
|
+
import logging
|
|
10
|
+
from modelq.app.tasks import Task
|
|
11
|
+
from modelq.exceptions import TaskProcessingError, TaskTimeoutError
|
|
12
|
+
from modelq.app.cache import Cache
|
|
13
|
+
from modelq.app.middleware import Middleware
|
|
14
|
+
|
|
15
|
+
# Set up logging configuration.
# NOTE(review): basicConfig() runs when this module is imported, so importing
# it installs an INFO-level handler on the root logger unless the application
# has already configured logging -- confirm this is intended for a library.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by ModelQ and its worker threads.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
class ModelQ:
    """ModelQ class for managing machine learning tasks with Redis queueing and streaming.

    Producers call functions decorated with :meth:`task`; each call records the
    task in the SQLite cache and pushes its JSON description onto the
    ``ml_tasks`` Redis list.  Workers started with :meth:`start_workers` pop
    tasks from that list, run the registered function and publish the outcome
    under ``task_result:<task_id>``.  Streaming tasks additionally append every
    yielded chunk to the ``task_stream:<task_id>`` Redis stream.
    """

    # Seconds a published task result stays in Redis before it expires.
    RESULT_TTL_SECONDS = 3600
    # Pause after handing back a task this server may not run, so a lone worker
    # does not pop and re-push the same task in a tight loop.
    REQUEUE_BACKOFF_SECONDS = 0.1
    # Pause before a worker resumes after an unexpected error (for example when
    # Redis is unreachable), so it does not spin and flood the log.
    CRASH_BACKOFF_SECONDS = 1.0

    def __init__(
        self,
        host: str = "localhost",
        server_id: Optional[str] = None,
        username: Optional[str] = None,
        port: int = 6379,
        db: int = 0,
        password: Optional[str] = None,
        ssl: bool = False,
        ssl_cert_reqs: Any = None,
        cache_db_path: str = "cache.db",
        **kwargs
    ):
        """Connect to Redis, open the SQLite cache and register this server.

        Args:
            host: Redis host name.
            server_id: Identifier stored in the ``servers`` hash; a random
                UUID is generated when omitted.
            username: Redis username, if any.
            port: Redis port.
            db: Redis database index.
            password: Redis password, if any.
            ssl: Whether to connect over TLS.
            ssl_cert_reqs: Certificate verification mode forwarded to the
                Redis client; when ``None`` the client's own default is used.
            cache_db_path: Path of the SQLite cache database.
            **kwargs: Extra keyword arguments forwarded to ``redis.Redis``.
        """
        self.redis_client = self._connect_to_redis(
            host=host,
            port=port,
            db=db,
            password=password,
            username=username,
            ssl=ssl,
            ssl_cert_reqs=ssl_cert_reqs,
            **kwargs
        )
        self.server_id = server_id or str(uuid.uuid4())
        self.allowed_tasks = set()
        # Task name -> undecorated function.  Kept separate from instance
        # attributes so a task can never replace one of ModelQ's own methods.
        self._task_functions: Dict[str, Any] = {}
        self.cache_db_path = cache_db_path
        self.cache = Cache(db_path=cache_db_path)
        self.task_configurations: Dict[str, Dict[str, Any]] = {}
        self.middleware: Optional[Middleware] = None
        self.register_server()
        self.worker_threads = []

    def _connect_to_redis(
        self,
        host: str,
        port: int,
        db: int,
        password: Optional[str],
        ssl: bool,
        ssl_cert_reqs: Any,
        username: Optional[str],
        **kwargs
    ) -> redis.Redis:
        """Build the Redis client from the given connection settings.

        Every argument is honoured for every host (previously ``localhost``
        ignored all of them and always used database 3, and ``db`` was never
        forwarded).  Extra ``kwargs`` are passed through to ``redis.Redis``.
        """
        options = dict(kwargs)
        options.update(
            host=host,
            port=port,
            db=db,
            password=password,
            username=username,
            ssl=ssl,
        )
        # Only override the client's certificate policy when the caller chose
        # one explicitly; the old code passed the ``ssl`` flag here by mistake.
        if ssl_cert_reqs is not None:
            options["ssl_cert_reqs"] = ssl_cert_reqs
        return redis.Redis(**options)

    def register_server(self):
        """Registers the server in Redis with its capabilities."""
        self.redis_client.hset(
            "servers",
            self.server_id,
            json.dumps({"allowed_tasks": list(self.allowed_tasks), "status": "idle"}),
        )

    def update_server_status(self, status: str):
        """Updates the server status in Redis."""
        raw = self.redis_client.hget("servers", self.server_id)
        # The hash entry may be missing (e.g. the key was flushed); rebuild it
        # instead of crashing on json.loads(None).
        server_data = json.loads(raw) if raw else {"allowed_tasks": list(self.allowed_tasks)}
        server_data["status"] = status
        self.redis_client.hset("servers", self.server_id, json.dumps(server_data))

    def enqueue_task(self, task_name: str, payload: dict):
        """Push a task onto the ``ml_tasks`` list with status ``queued``.

        Args:
            task_name: Despite its name, internal callers pass the full task
                dictionary (``Task.to_dict()``) here.  A plain task name is
                also accepted, in which case a new task is built from it and
                ``payload``.
            payload: Task payload; only used when ``task_name`` is a string.
        """
        if isinstance(task_name, dict):
            task = {**task_name, "status": "queued"}
        else:
            task = {**Task(task_name=task_name, payload=payload).to_dict(), "status": "queued"}

        self.redis_client.rpush("ml_tasks", json.dumps(task))

    def task(self, task_class=Task, timeout: Optional[int] = None, stream: bool = False, retries: int = 0):
        """Decorator to create a task. Allows specifying a custom task class, timeout, streaming support, and retries.

        Calling the decorated function does not run it: it enqueues a task
        carrying the call arguments and returns the task object.
        """
        def decorator(func):
            task_name = func.__name__

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                payload = {
                    "args": args,
                    "kwargs": kwargs,
                    "timeout": timeout,
                    "stream": stream,
                    "retries": retries
                }
                task = task_class(task_name=task_name, payload=payload)
                if stream:
                    task.stream = True
                # Cache the "queued" record before the task becomes visible to
                # workers; otherwise a fast worker could store its final status
                # first and have it overwritten by this write.
                task.store_in_cache(self.cache_db_path)
                self.enqueue_task(task.to_dict(), payload=payload)
                return task

            # Keep the function reachable as an instance attribute for
            # backwards compatibility, but never let it shadow an existing
            # attribute or method of this object.
            if task_name in self._task_functions or not hasattr(self, task_name):
                setattr(self, task_name, func)
            self._task_functions[task_name] = func
            self.allowed_tasks.add(task_name)
            self.register_server()
            return wrapper
        return decorator

    def start_workers(self, no_of_workers: int = 1):
        """Start ``no_of_workers`` daemon threads that consume the task queue.

        Does nothing when workers from a previous call are still alive.
        """
        if any(thread.is_alive() for thread in self.worker_threads):
            return  # Workers are already running

        self.check_middleware("before_worker_boot")

        for i in range(no_of_workers):
            worker_thread = threading.Thread(target=self._worker_loop, args=(i,))
            worker_thread.daemon = True
            worker_thread.start()
            self.worker_threads.append(worker_thread)

        # Log after all workers have started
        task_names = ', '.join(self.allowed_tasks) if self.allowed_tasks else 'No tasks registered'
        logger.info(f"ModelQ worker started successfully with {no_of_workers} worker(s). Connected to Redis at {self.redis_client.connection_pool.connection_kwargs['host']}:{self.redis_client.connection_pool.connection_kwargs['port']}. Registered tasks: {task_names}")

    def _worker_loop(self, worker_id):
        """Pop tasks from ``ml_tasks`` forever and run the ones this server allows."""
        while True:
            try:
                # Update server status to idle while waiting for tasks
                self.update_server_status(f"worker_{worker_id}: idle")
                task_data = self.redis_client.blpop("ml_tasks")
                if not task_data:
                    continue
                # Update server status to busy when a task is picked up
                self.update_server_status(f"worker_{worker_id}: busy")
                _, task_json = task_data
                task = Task.from_dict(json.loads(task_json))
                if task.task_name not in self.allowed_tasks:
                    # Requeue the task if this server cannot process it
                    self.redis_client.rpush("ml_tasks", task_json)
                    time.sleep(self.REQUEUE_BACKOFF_SECONDS)
                    continue
                self._run_task_with_retries(worker_id, task)
            except Exception as e:
                logger.error(f"Worker {worker_id} crashed with error: {e}. Restarting worker...")
                time.sleep(self.CRASH_BACKOFF_SECONDS)

    def _run_task_with_retries(self, worker_id, task: Task) -> None:
        """Process one task, re-enqueueing it on failure while retries remain."""
        try:
            logger.info(f"Worker {worker_id} started processing task: {task.task_name}")
            start_time = time.time()
            self.process_task(task)
            end_time = time.time()
            logger.info(f"Worker {worker_id} finished task: {task.task_name} in {end_time - start_time:.2f} seconds")
        except TaskProcessingError as e:
            logger.error(f"Worker {worker_id} encountered a TaskProcessingError while processing task '{task.task_name}': {e}")
            self._retry_if_allowed(task)
        except Exception as e:
            logger.error(f"Worker {worker_id} encountered an unexpected error while processing task '{task.task_name}': {e}")
            self._retry_if_allowed(task)

    def _retry_if_allowed(self, task: Task) -> None:
        """Re-enqueue ``task`` with one fewer retry if its payload has retries left."""
        if task.payload.get("retries", 0) > 0:
            task.payload["retries"] -= 1
            self.enqueue_task(task.to_dict(), payload=task.payload)

    def check_middleware(self, middleware_event: str):
        """Log ``middleware_event`` and forward it to the middleware, if one is set."""
        logger.info(f"Middleware event triggered: {middleware_event}")
        if self.middleware:
            self.middleware.execute(event=middleware_event)

    def _publish_result(self, task: Task) -> None:
        """Write the task's current state to ``task_result:<task_id>`` with a TTL."""
        self.redis_client.set(
            f"task_result:{task.task_id}",
            json.dumps(task.to_dict()),
            ex=self.RESULT_TTL_SECONDS,
        )

    def _record_failure(self, task: Task, reason: str) -> None:
        """Mark ``task`` as failed with ``reason`` in both Redis and the cache."""
        task.status = "failed"
        task.result = reason
        self._publish_result(task)
        # Store failed task status in cache
        task.store_in_cache(self.cache_db_path)

    def process_task(self, task: Task) -> None:
        """Processes a given task.

        Runs the function registered under ``task.task_name`` with the
        payload's ``args``/``kwargs``, publishes the outcome to Redis and
        stores it in the cache.

        Raises:
            TaskProcessingError: if the task is not allowed, has no registered
                function, or the function raises (including on timeout).
        """
        if task.task_name not in self.allowed_tasks:
            self._record_failure(task, "Task not allowed")
            logger.error(f"Task {task.task_name} is not allowed")
            raise TaskProcessingError(task.task_name, "Task not allowed")

        # Fall back to the instance attribute for functions attached by hand.
        task_function = self._task_functions.get(task.task_name) or getattr(self, task.task_name, None)
        if not task_function:
            self._record_failure(task, "Task function not found")
            logger.error(f"Task {task.task_name} failed because the task function was not found")
            raise TaskProcessingError(task.task_name, "Task function not found")

        try:
            args = task.payload.get("args", [])
            kwargs = task.payload.get("kwargs", {})
            logger.info(f"Processing task: {task.task_name} with args: {args} and kwargs: {kwargs}")
            start_time = time.time()
            timeout = task.payload.get("timeout", None)
            if task.payload.get("stream", False):
                # Streaming tasks are generators: every yielded chunk is
                # appended to the task's Redis stream.  No timeout is applied.
                for result in task_function(*args, **kwargs):
                    task.status = "in_progress"
                    self.redis_client.xadd(f"task_stream:{task.task_id}", {"result": json.dumps(result)})
                # Mark the task as completed when streaming ends
                task.status = "completed"
            else:
                if timeout:
                    result = self._run_with_timeout(task_function, timeout, *args, **kwargs)
                else:
                    result = task_function(*args, **kwargs)
                task.result = task._convert_to_string(result)
                task.status = "completed"
            self._publish_result(task)
            end_time = time.time()
            logger.info(f"Task {task.task_name} completed successfully in {end_time - start_time:.2f} seconds")
            # Store updated task status in cache
            task.store_in_cache(self.cache_db_path)
        except Exception as e:
            self._record_failure(task, str(e))
            logger.error(f"Task {task.task_name} failed with error: {e}")
            raise TaskProcessingError(task.task_name, str(e)) from e

    def _run_with_timeout(self, func, timeout, *args, **kwargs):
        """Runs a function with a timeout.

        ``func`` runs in a helper thread that is joined for at most
        ``timeout`` seconds.  The thread cannot be cancelled: after a timeout
        it keeps running in the background, so it is a daemon thread and does
        not block interpreter shutdown.

        Raises:
            TaskTimeoutError: if ``func`` has not finished within ``timeout``.
            Exception: whatever ``func`` itself raised.
        """
        result = [None]
        exception = [None]

        def target():
            try:
                result[0] = func(*args, **kwargs)
            except Exception as e:
                exception[0] = e

        thread = threading.Thread(target=target, daemon=True)
        thread.start()
        thread.join(timeout)
        if thread.is_alive():
            logger.error(f"Task exceeded timeout of {timeout} seconds")
            # TaskTimeoutError formats its argument as an identifier, so pass
            # the function name rather than a full sentence.
            func_name = getattr(func, "__name__", repr(func))
            raise TaskTimeoutError(f"{func_name} (exceeded timeout of {timeout} seconds)")
        if exception[0]:
            raise exception[0]
        return result[0]

    def get_all_queued_tasks(self) -> list:
        """Retrieves all tasks with status 'queued' from the SQLite cache database."""
        from contextlib import closing

        # sqlite3's own context manager only ends the transaction; closing()
        # makes sure the connection itself is released.
        with closing(sqlite3.connect(self.cache_db_path)) as conn:
            rows = conn.execute(
                'SELECT task_id, task_name, status FROM tasks WHERE status = ?', ("queued",)
            ).fetchall()
        return [{"task_id": row[0], "task_name": row[1], "status": row[2]} for row in rows]

    def get_task_status(self, task_id: str) -> Optional[str]:
        """Retrieves the status of a particular task by task ID from the SQLite cache database.

        Returns ``None`` when the cache has no row for ``task_id``.
        """
        from contextlib import closing

        with closing(sqlite3.connect(self.cache_db_path)) as conn:
            row = conn.execute('SELECT status FROM tasks WHERE task_id = ?', (task_id,)).fetchone()
        if row:
            return row[0]
        return None
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import sqlite3
|
|
2
|
+
import os
|
|
3
|
+
import json
|
|
4
|
+
from typing import Optional
|
|
5
|
+
from modelq.app.tasks import Task
|
|
6
|
+
|
|
7
|
+
class Cache:
    """SQLite-backed store holding one row per task in a ``tasks`` table."""

    def __init__(self, db_path: str = "cache.db") -> None:
        """Remember ``db_path`` and make sure the ``tasks`` table exists."""
        self.db_path = db_path
        self._initialize_db()

    def _open(self):
        """Return a context manager that yields a connection and closes it on exit.

        ``with sqlite3.connect(...)`` alone only commits or rolls back; it
        leaves the connection open, so it is wrapped in ``closing``.
        """
        from contextlib import closing

        return closing(sqlite3.connect(self.db_path))

    def _initialize_db(self):
        """Creates the ``tasks`` table if it does not exist yet.

        The check is done by SQLite (``IF NOT EXISTS``) rather than by testing
        whether the file exists, so a pre-existing database file without the
        table is initialised too.
        """
        with self._open() as conn, conn:
            conn.execute(
                '''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    task_name TEXT,
                    payload TEXT,
                    status TEXT,
                    result TEXT,
                    timestamp REAL
                )
                '''
            )

    def _convert_to_string(self, data) -> str:
        """Converts any data type to a string representation.

        Dicts, lists, numbers and booleans are JSON-encoded; everything else,
        and anything JSON cannot encode, falls back to ``str(data)``.
        """
        try:
            if isinstance(data, (dict, list, int, float, bool)):
                return json.dumps(data)
            return str(data)
        except TypeError:
            return str(data)

    @staticmethod
    def _decode_result(text):
        """Decode a stored result column.

        Text that looks like a JSON object or array is parsed; if parsing
        fails (for example an error message such as ``"[Errno 2] ..."``) the
        raw text is returned unchanged.
        """
        if text and text.startswith(('{', '[')):
            try:
                return json.loads(text)
            except ValueError:
                return text
        return text

    def store_task(self, task: "Task") -> None:
        """Stores a new task or updates an existing one in the SQLite database."""
        row = (
            self._convert_to_string(task.task_id),
            self._convert_to_string(task.task_name),
            self._convert_to_string(task.payload),
            self._convert_to_string(task.status),
            self._convert_to_string(task.result),
            # Non-numeric timestamps are stored as NULL.
            task.timestamp if isinstance(task.timestamp, (int, float)) else None,
        )
        with self._open() as conn, conn:
            conn.execute(
                '''
                INSERT OR REPLACE INTO tasks (task_id, task_name, payload, status, result, timestamp)
                VALUES (?, ?, ?, ?, ?, ?)
                ''',
                row
            )

    def get_task(self, task_id: str) -> Optional["Task"]:
        """Retrieves a task from the SQLite database by its task ID.

        Returns ``None`` when no row matches ``task_id``.
        """
        with self._open() as conn:
            row = conn.execute(
                'SELECT task_id, task_name, payload, status, result, timestamp '
                'FROM tasks WHERE task_id = ?',
                (task_id,),
            ).fetchone()
        if not row:
            return None
        try:
            payload = json.loads(row[2])
        except (TypeError, ValueError):
            # Payloads that could not be JSON-encoded were stored via str();
            # hand the raw text back instead of failing the lookup.
            payload = row[2]
        return Task.from_dict({
            "task_id": row[0],
            "task_name": row[1],
            "payload": payload,
            "status": row[3],
            "result": self._decode_result(row[4]),
            "timestamp": row[5]
        })
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from abc import ABC , abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Middleware(ABC):
    """Abstract base for hooks invoked around the worker lifecycle."""

    def __init__(self) -> None:
        pass

    def execute(self, event):
        """Invoke the hook matching ``event``; other events are ignored."""
        is_boot_event = event == "before_worker_boot"
        if is_boot_event:
            self.before_worker_boot()

    @abstractmethod
    def before_worker_boot(self):
        """Called before the worker process starts up."""
        pass
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
import time
|
|
3
|
+
import json
|
|
4
|
+
import redis
|
|
5
|
+
import sqlite3
|
|
6
|
+
import base64
|
|
7
|
+
from typing import Any, Optional, Generator
|
|
8
|
+
from modelq.exceptions import TaskTimeoutError
|
|
9
|
+
from PIL import Image, PngImagePlugin
|
|
10
|
+
import io
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Task:
    """A unit of work exchanged between producers and workers.

    A task carries its name, payload, status and result, can be converted to
    and from a plain dictionary, stored in the SQLite cache, and polled for
    its result (or streamed chunks) through a Redis client.
    """

    def __init__(self, task_name: str, payload: dict, timeout: int = 15):
        """Create a queued task with a fresh UUID and the current timestamp.

        Args:
            task_name: Name of the function the worker should run.
            payload: Call description for the task.
            timeout: Default number of seconds :meth:`get_result` waits.
        """
        self.task_id = str(uuid.uuid4())
        self.task_name = task_name
        self.payload = payload
        self.status = "queued"
        self.result = None
        self.timestamp = time.time()
        self.timeout = timeout
        self.stream = False
        self.combined_result = ""
        # Cache database last used by store_in_cache(); reused when the method
        # is later called without an explicit path.
        self._cache_db_path = "cache.db"

    def to_dict(self):
        """Return the task as a dictionary of its serialisable fields."""
        return {
            "task_id": self.task_id,
            "task_name": self.task_name,
            "payload": self.payload,
            "status": self.status,
            "result": self.result,
            "timestamp": self.timestamp,
            "stream": self.stream
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Task':
        """Rebuild a task from a dictionary produced by :meth:`to_dict`.

        ``result`` and ``stream`` are optional; the other keys are required.
        """
        task = cls(task_name=data["task_name"], payload=data["payload"])
        task.task_id = data["task_id"]
        task.status = data["status"]
        task.result = data.get("result")
        task.timestamp = data["timestamp"]
        task.stream = data.get("stream", False)
        return task

    def _convert_to_string(self, data: Any) -> str:
        """Converts any data type to a string representation, including PIL images.

        Dicts, lists, numbers and booleans are JSON-encoded, PIL images become
        a base64 PNG data URI, and everything else (or anything JSON cannot
        encode) falls back to ``str(data)``.
        """
        # Strings are by far the most common input; return them directly.
        if isinstance(data, str):
            return str(data)
        try:
            if isinstance(data, (dict, list, int, float, bool)):
                return json.dumps(data)
            elif isinstance(data, (Image.Image, PngImagePlugin.PngImageFile)):
                buffered = io.BytesIO()
                data.save(buffered, format="PNG")
                return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8")
            return str(data)
        except TypeError:
            return str(data)

    def store_in_cache(self, db_path: Optional[str] = None) -> None:
        """Stores the task in the SQLite cache database.

        Args:
            db_path: Database file to write to.  When omitted, the path from
                the most recent call is reused (``"cache.db"`` if there was
                none), so later updates made by :meth:`get_result` and
                :meth:`get_stream` land in the same database.
        """
        from contextlib import closing

        if db_path is None:
            db_path = getattr(self, "_cache_db_path", "cache.db")
        else:
            self._cache_db_path = db_path

        row = (
            self._convert_to_string(self.task_id),
            self._convert_to_string(self.task_name),
            self._convert_to_string(self.payload),
            self._convert_to_string(self.status),
            self._convert_to_string(self.result),
            # Non-numeric timestamps are stored as NULL.
            self.timestamp if isinstance(self.timestamp, (int, float)) else None,
        )
        # closing() releases the connection; the inner ``conn`` context
        # commits on success and rolls back on error.
        with closing(sqlite3.connect(db_path)) as conn, conn:
            conn.execute(
                '''
                INSERT OR REPLACE INTO tasks (task_id, task_name, payload, status, result, timestamp)
                VALUES (?, ?, ?, ?, ?, ?)
                ''',
                row
            )

    def get_result(self, redis_client: "redis.Redis", timeout: int = None) -> Any:
        """Waits for the result of the task until the timeout.

        Polls ``task_result:<task_id>`` once per second.  The stored result is
        returned whatever the final status is; check ``self.status`` to tell a
        completed task from a failed one.

        Raises:
            TaskTimeoutError: if no result appears within ``timeout`` seconds
                (``self.timeout`` when ``timeout`` is not given).
        """
        if not timeout:
            timeout = self.timeout

        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            task_json = redis_client.get(f"task_result:{self.task_id}")
            if task_json:
                task_data = json.loads(task_json)
                self.result = task_data.get("result")
                self.status = task_data.get("status")
                # Store the updated task status in cache
                self.store_in_cache()
                return self.result
            time.sleep(1)
        raise TaskTimeoutError(self.task_id)

    def get_stream(self, redis_client: "redis.Redis") -> Generator[Any, None, None]:
        """Generator to yield results from a streaming task.

        Yields every chunk appended to ``task_stream:<task_id>`` and stops
        once the task's published status is ``completed`` or ``failed`` and
        the stream has been fully drained.  On completion ``self.result``
        holds the concatenated chunks; on failure it holds the published
        error text.  The final state is written to the cache.
        """
        stream_key = f"task_stream:{self.task_id}"
        result_key = f"task_result:{self.task_id}"
        last_id = "0"

        while True:
            # Look at the final status BEFORE reading: every chunk is written
            # before the status is published, so an empty read after seeing a
            # final status means nothing was missed.
            final_state = None
            task_json = redis_client.get(result_key)
            if task_json:
                task_data = json.loads(task_json)
                if task_data.get("status") in ("completed", "failed"):
                    final_state = task_data

            # Once the task is finished there is nothing to wait for, so do a
            # non-blocking read; otherwise block for up to one second.
            results = redis_client.xread(
                {stream_key: last_id},
                block=None if final_state else 1000,
                count=10,
            )
            received = False
            for _, messages in results or []:
                for message_id, message_data in messages:
                    received = True
                    # Field names are bytes unless the client decodes responses.
                    raw = message_data[b"result"] if b"result" in message_data else message_data["result"]
                    if isinstance(raw, bytes):
                        raw = raw.decode("utf-8")
                    result = json.loads(raw)
                    yield result
                    last_id = message_id
                    # Concatenate result for storing combined response;
                    # non-string chunks are appended in their JSON form.
                    self.combined_result += result if isinstance(result, str) else json.dumps(result)

            if final_state is not None and not received:
                # Store the final task status and result in cache
                self.status = final_state.get("status")
                if self.status == "completed":
                    self.result = self.combined_result
                else:
                    self.result = final_state.get("result")
                self.store_in_cache()
                return
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from PIL import Image
|
|
2
|
+
import io
|
|
3
|
+
import base64
|
|
4
|
+
|
|
5
|
+
def base64_to_image(base64_string):
    """Decode a base64 string, optionally prefixed with a data-URI header, into a PIL image.

    When the input contains a comma, the piece right after the first comma is
    taken as the base64 data; otherwise the whole string is decoded.
    """
    pieces = base64_string.split(",")
    encoded = pieces[1] if len(pieces) > 1 else base64_string
    raw_bytes = base64.b64decode(encoded)
    return Image.open(io.BytesIO(raw_bytes))
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
class TaskTimeoutError(Exception):
    """Raised when waiting for a task's result exceeds the allowed time."""

    def __init__(self, task_id: str) -> None:
        # Keep the identifier on the exception so handlers can inspect it.
        self.task_id = task_id
        description = f"Task {task_id} timed out waiting for result."
        super().__init__(description)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TaskProcessingError(Exception):
    """Raised when a task fails while it is being processed."""

    def __init__(self, task_name: str, message: str):
        # Expose both parts separately in addition to the formatted text.
        self.task_name = task_name
        self.message = message
        description = f"Error processing task {task_name}: {message}"
        super().__init__(description)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
[tool.poetry]
|
|
2
|
+
name = "modelq"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Celery-like task queue for ML inference."
|
|
5
|
+
authors = ["Tanmaypatil123 <tanmay@modelslab.com>"]
|
|
6
|
+
readme = "README.md"
|
|
7
|
+
|
|
8
|
+
[tool.poetry.scripts]
|
|
9
|
+
modelq = "modelq.app:run_worker"
|
|
10
|
+
|
|
11
|
+
[tool.poetry.dependencies]
|
|
12
|
+
python = "^3.9"
|
|
13
|
+
click = "^8.0.0"
|
|
14
|
+
redis = "^4.0.0"
|
|
15
|
+
|
|
16
|
+
[build-system]
|
|
17
|
+
requires = ["poetry-core"]
|
|
18
|
+
build-backend = "poetry.core.masonry.api"
|