avtomatika-1.0b7-py3-none-any.whl → avtomatika-1.0b9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika/api/handlers.py +3 -255
- avtomatika/api/routes.py +42 -63
- avtomatika/app_keys.py +2 -0
- avtomatika/config.py +18 -0
- avtomatika/constants.py +2 -26
- avtomatika/data_types.py +4 -23
- avtomatika/dispatcher.py +9 -26
- avtomatika/engine.py +127 -6
- avtomatika/executor.py +53 -25
- avtomatika/health_checker.py +23 -5
- avtomatika/history/base.py +60 -6
- avtomatika/history/noop.py +18 -7
- avtomatika/history/postgres.py +8 -6
- avtomatika/history/sqlite.py +7 -5
- avtomatika/metrics.py +1 -1
- avtomatika/reputation.py +46 -40
- avtomatika/s3.py +379 -0
- avtomatika/security.py +56 -74
- avtomatika/services/__init__.py +0 -0
- avtomatika/services/worker_service.py +266 -0
- avtomatika/storage/base.py +55 -4
- avtomatika/storage/memory.py +56 -7
- avtomatika/storage/redis.py +214 -251
- avtomatika/utils/webhook_sender.py +44 -2
- avtomatika/watcher.py +35 -35
- avtomatika/ws_manager.py +10 -9
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/METADATA +81 -7
- avtomatika-1.0b9.dist-info/RECORD +48 -0
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/WHEEL +1 -1
- avtomatika-1.0b7.dist-info/RECORD +0 -45
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/licenses/LICENSE +0 -0
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/top_level.txt +0 -0
avtomatika/storage/redis.py
CHANGED
@@ -35,6 +35,12 @@ class RedisStorage(StorageBackend):
     def _get_key(self, job_id: str) -> str:
         return f"{self._prefix}:{job_id}"
 
+    async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
+        """Gets the full info for a worker by its ID."""
+        key = f"orchestrator:worker:info:{worker_id}"
+        data = await self._redis.get(key)
+        return self._unpack(data) if data else None
+
     @staticmethod
     def _pack(data: Any) -> bytes:
         return packb(data, use_bin_type=True)
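The first hunk adds a `get_worker_info` lookup: a plain GET decoded with the class's msgpack helpers. A minimal standalone sketch of the same round trip, assuming a local Redis with redis-py's asyncio client and `msgpack` installed; the worker ID and payload are invented:

```python
# Sketch of the GET + msgpack round trip behind get_worker_info.
# Assumes a local Redis; "w1" and the payload are illustrative.
import asyncio
from msgpack import packb, unpackb
from redis.asyncio import Redis

async def main() -> None:
    r = Redis()
    key = "orchestrator:worker:info:w1"
    await r.set(key, packb({"status": "idle", "reputation": 1.0}, use_bin_type=True), ex=60)
    data = await r.get(key)
    print(unpackb(data, raw=False) if data else None)  # {'status': 'idle', 'reputation': 1.0}
    await r.aclose()

asyncio.run(main())
```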
@@ -55,10 +61,8 @@ class RedisStorage(StorageBackend):
         key = f"orchestrator:task_queue:{worker_type}"
 
         pipe = self._redis.pipeline()
-        pipe.zcard(key)
-        # Get the top 3 highest priority bids (scores)
+        pipe.zcard(key)
         pipe.zrange(key, -3, -1, withscores=True, score_cast_func=float)
-        # Get the top 3 lowest priority bids (scores)
         pipe.zrange(key, 0, 2, withscores=True, score_cast_func=float)
         results = await pipe.execute()
 
@@ -66,7 +70,6 @@ class RedisStorage(StorageBackend):
         top_bids = [score for _, score in reversed(top_bids_raw)]
         bottom_bids = [score for _, score in bottom_bids_raw]
 
-        # Simple average calculation, can be improved for large queues
         all_scores = [s for _, s in await self._redis.zrange(key, 0, -1, withscores=True, score_cast_func=float)]
         avg_bid = sum(all_scores) / len(all_scores) if all_scores else 0
 
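These two hunks only drop comments; for context, the surrounding method gathers queue statistics in one pipeline: ZCARD for the depth, two ZRANGE slices for the highest and lowest bids, then a full-range ZRANGE for the average (the O(N) read the removed comment flagged for large queues). A sketch of those reads against a throwaway key; the key name and scores are made up:

```python
# Sketch of the pipelined queue-stats reads from the hunks above.
import asyncio
from redis.asyncio import Redis

async def main() -> None:
    r = Redis()
    key = "orchestrator:task_queue:demo"
    await r.zadd(key, {"t1": 1.0, "t2": 5.0, "t3": 9.0})
    pipe = r.pipeline()
    pipe.zcard(key)
    pipe.zrange(key, -3, -1, withscores=True)  # top 3 bids
    pipe.zrange(key, 0, 2, withscores=True)    # bottom 3 bids
    depth, top_raw, bottom_raw = await pipe.execute()
    all_scores = [s for _, s in await r.zrange(key, 0, -1, withscores=True)]
    avg = sum(all_scores) / len(all_scores) if all_scores else 0
    print(depth, [s for _, s in reversed(top_raw)], [s for _, s in bottom_raw], avg)
    await r.aclose()

asyncio.run(main())
```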
@@ -102,8 +105,6 @@ class RedisStorage(StorageBackend):
             await pipe.watch(key)
             current_state_raw = await pipe.get(key)
             current_state = self._unpack(current_state_raw) if current_state_raw else {}
-
-            # Simple dictionary merge. For nested structures, a deep merge may be required.
             current_state.update(update_data)
 
             pipe.multi()
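This hunk sits inside an optimistic WATCH/MULTI read-modify-write: the key is watched, read, shallow-merged, and rewritten only if no other client touched it in between. A minimal sketch of that pattern, using JSON in place of the class's msgpack helpers; the key and update are illustrative:

```python
# Minimal sketch of the WATCH/MULTI read-modify-write used above.
import asyncio
import json
from redis.asyncio import Redis
from redis.exceptions import WatchError

async def merge_update(r: Redis, key: str, update: dict) -> dict:
    async with r.pipeline(transaction=True) as pipe:
        while True:
            try:
                await pipe.watch(key)           # optimistic lock on the key
                raw = await pipe.get(key)
                state = json.loads(raw) if raw else {}
                state.update(update)            # shallow merge, as in the diff
                pipe.multi()
                pipe.set(key, json.dumps(state))
                await pipe.execute()
                return state
            except WatchError:
                continue                        # key changed under us; retry

async def main() -> None:
    r = Redis()
    print(await merge_update(r, "demo:state", {"status": "busy"}))
    await r.aclose()

asyncio.run(main())
```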
@@ -119,55 +120,46 @@ class RedisStorage(StorageBackend):
         worker_info: dict[str, Any],
         ttl: int,
     ) -> None:
-        """Registers a worker in Redis."""
+        """Registers a worker in Redis and updates indexes."""
         worker_info.setdefault("reputation", 1.0)
         key = f"orchestrator:worker:info:{worker_id}"
-        await self._redis.set(key, self._pack(worker_info), ex=ttl)
+        tasks_key = f"orchestrator:worker:tasks:{worker_id}"
 
-    async def enqueue_task_for_worker(
-        self,
-        worker_id: str,
-        task_payload: dict[str, Any],
-        priority: float,
-    ) -> None:
-        """Adds a task to the priority queue (Sorted Set) for a worker."""
-        key = f"orchestrator:task_queue:{worker_id}"
-        await self._redis.zadd(key, {self._pack(task_payload): priority})
+        async with self._redis.pipeline(transaction=True) as pipe:
+            pipe.set(key, self._pack(worker_info), ex=ttl)
+            pipe.sadd("orchestrator:index:workers:all", worker_id)
 
-    async def dequeue_task_for_worker(
-        self,
-        worker_id: str,
-        timeout: int,
-    ) -> dict[str, Any] | None:
-        """Retrieves the highest priority task from the queue (Sorted Set),
-        using the blocking BZPOPMAX operation.
-        """
-        key = f"orchestrator:task_queue:{worker_id}"
-        try:
-            # BZPOPMAX returns a tuple (key, member, score)
-            result = await self._redis.bzpopmax([key], timeout=timeout)
-            return self._unpack(result[1]) if result else None
-        except CancelledError:
-            return None
-        except ResponseError as e:
-            # Error handling if `fakeredis` does not support BZPOPMAX
-            if "unknown command" in str(e).lower() or "wrong number of arguments" in str(e).lower():
-                logger.warning(
-                    "BZPOPMAX is not supported (likely running with fakeredis). "
-                    "Falling back to non-blocking ZPOPMAX for testing.",
-                )
-                # Non-blocking fallback for tests
-                res = await self._redis.zpopmax(key)
-                if res:
-                    return self._unpack(res[0][0])
-            raise e
+            if worker_info.get("status", "idle") == "idle":
+                pipe.sadd("orchestrator:index:workers:idle", worker_id)
+            else:
+                pipe.srem("orchestrator:index:workers:idle", worker_id)
 
-    async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
-        """Gets the full info for a worker by its ID."""
+            supported_tasks = worker_info.get("supported_tasks", [])
+            if supported_tasks:
+                pipe.sadd(tasks_key, *supported_tasks)
+                for task in supported_tasks:
+                    pipe.sadd(f"orchestrator:index:workers:task:{task}", worker_id)
+
+            await pipe.execute()
+
+    async def deregister_worker(self, worker_id: str) -> None:
+        """Deletes the worker key and removes it from all indexes."""
         key = f"orchestrator:worker:info:{worker_id}"
-        data = await self._redis.get(key)
-        return self._unpack(data) if data else None
-
+        tasks_key = f"orchestrator:worker:tasks:{worker_id}"
+
+        tasks = await self._redis.smembers(tasks_key)  # type: ignore
+
+        async with self._redis.pipeline(transaction=True) as pipe:
+            pipe.delete(key)
+            pipe.delete(tasks_key)
+            pipe.srem("orchestrator:index:workers:all", worker_id)
+            pipe.srem("orchestrator:index:workers:idle", worker_id)
+
+            for task in tasks:
+                task_str = task.decode("utf-8") if isinstance(task, bytes) else task
+                pipe.srem(f"orchestrator:index:workers:task:{task_str}", worker_id)
+
+            await pipe.execute()
 
     async def update_worker_status(
         self,
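Registration now writes the info blob plus the index keys (the `all` set, the `idle` set, a per-worker task set, and one membership set per task type) in a single transactional pipeline, and the new `deregister_worker` unwinds exactly those keys. A sketch of the resulting layout, with an invented worker; key names follow the diff:

```python
# Sketch of the key layout register_worker maintains.
import asyncio
from msgpack import packb
from redis.asyncio import Redis

async def main() -> None:
    r = Redis()
    wid, ttl = "w1", 30
    info = {"status": "idle", "supported_tasks": ["transcode"], "reputation": 1.0}
    async with r.pipeline(transaction=True) as pipe:
        pipe.set(f"orchestrator:worker:info:{wid}", packb(info, use_bin_type=True), ex=ttl)
        pipe.sadd("orchestrator:index:workers:all", wid)
        pipe.sadd("orchestrator:index:workers:idle", wid)  # status == "idle"
        for task in info["supported_tasks"]:
            pipe.sadd(f"orchestrator:worker:tasks:{wid}", task)
            pipe.sadd(f"orchestrator:index:workers:task:{task}", wid)
        await pipe.execute()
    print(await r.smembers("orchestrator:index:workers:task:transcode"))  # {b'w1'}
    await r.aclose()

asyncio.run(main())
```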
@@ -184,102 +176,154 @@ class RedisStorage(StorageBackend):
                     return None
 
                 current_state = self._unpack(current_state_raw)
-
-                # Create a potential new state to compare against the current one
                 new_state = current_state.copy()
                 new_state.update(status_update)
 
                 pipe.multi()
 
-                # Only write to Redis if the state has actually changed.
                 if new_state != current_state:
                     pipe.set(key, self._pack(new_state), ex=ttl)
-
+                    old_status = current_state.get("status", "idle")
+                    new_status = new_state.get("status", "idle")
+
+                    if old_status != new_status:
+                        if new_status == "idle":
+                            pipe.sadd("orchestrator:index:workers:idle", worker_id)
+                        else:
+                            pipe.srem("orchestrator:index:workers:idle", worker_id)
+                    current_state = new_state
                 else:
-                    # If nothing changed, just refresh the TTL to keep the worker alive.
                     pipe.expire(key, ttl)
 
                 await pipe.execute()
                 return current_state
             except WatchError:
-                # In case of a conflict, the operation can be repeated,
-                # but for a heartbeat it is not critical, you can just skip it.
                 return None
 
-    async def update_worker_data(
-        self,
-        worker_id: str,
-        update_data: dict[str, Any],
-    ) -> dict[str, Any] | None:
+    async def find_workers_for_task(self, task_type: str) -> list[str]:
+        """Finds idle workers that support the given task using set intersection."""
+        task_index = f"orchestrator:index:workers:task:{task_type}"
+        idle_index = "orchestrator:index:workers:idle"
+        worker_ids = await self._redis.sinter(task_index, idle_index)  # type: ignore
+        return [wid.decode("utf-8") if isinstance(wid, bytes) else wid for wid in worker_ids]
+
+    async def enqueue_task_for_worker(self, worker_id: str, task_payload: dict[str, Any], priority: float) -> None:
+        key = f"orchestrator:task_queue:{worker_id}"
+        await self._redis.zadd(key, {self._pack(task_payload): priority})
+
+    async def dequeue_task_for_worker(self, worker_id: str, timeout: int) -> dict[str, Any] | None:
+        key = f"orchestrator:task_queue:{worker_id}"
+        try:
+            result = await self._redis.bzpopmax([key], timeout=timeout)
+            return self._unpack(result[1]) if result else None
+        except CancelledError:
+            return None
+        except ResponseError as e:
+            if "unknown command" in str(e).lower() or "wrong number of arguments" in str(e).lower():
+                res = await self._redis.zpopmax(key)
+                if res:
+                    return self._unpack(res[0][0])
+            raise e
+
+    async def refresh_worker_ttl(self, worker_id: str, ttl: int) -> bool:
+        was_set = await self._redis.expire(f"orchestrator:worker:info:{worker_id}", ttl)
+        return bool(was_set)
+
+    async def update_worker_data(self, worker_id: str, update_data: dict[str, Any]) -> dict[str, Any] | None:
         key = f"orchestrator:worker:info:{worker_id}"
         async with self._redis.pipeline(transaction=True) as pipe:
             try:
                 await pipe.watch(key)
-                current_state_raw = await pipe.get(key)
-                if not current_state_raw:
+                raw = await pipe.get(key)
+                if not raw:
                     return None
-
-                current_state = self._unpack(current_state_raw)
-                current_state.update(update_data)
-
+                data = self._unpack(raw)
+                data.update(update_data)
                 pipe.multi()
-
-                pipe.set(key, self._pack(current_state))
+                pipe.set(key, self._pack(data))
                 await pipe.execute()
-                return current_state
+                return data
             except WatchError:
-                # In case of a conflict, the operation can be repeated
-                logger.warning(
-                    f"WatchError during worker data update for {worker_id}, retrying.",
-                )
-                # In this case, it is better to repeat, as updating the reputation is important
                 return await self.update_worker_data(worker_id, update_data)
 
     async def get_available_workers(self) -> list[dict[str, Any]]:
-
-        worker_keys = [key async for key in self._redis.scan_iter("orchestrator:worker:info:*")]  # type: ignore[attr-defined]
-
+        worker_keys = [key async for key in self._redis.scan_iter("orchestrator:worker:info:*")]  # type: ignore
         if not worker_keys:
             return []
+        data_list = await self._redis.mget(worker_keys)
+        return [self._unpack(data) for data in data_list if data]
 
-
-
+    async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
+        if not worker_ids:
+            return []
+        keys = [f"orchestrator:worker:info:{wid}" for wid in worker_ids]
+        data_list = await self._redis.mget(keys)
+        return [self._unpack(data) for data in data_list if data]
+
+    async def get_active_worker_ids(self) -> list[str]:
+        worker_ids = await self._redis.smembers("orchestrator:index:workers:all")  # type: ignore
+        return [wid.decode("utf-8") if isinstance(wid, bytes) else wid for wid in worker_ids]
+
+    async def cleanup_expired_workers(self) -> None:
+        worker_ids = await self.get_active_worker_ids()
+        if not worker_ids:
+            return
+        pipe = self._redis.pipeline()
+        for wid in worker_ids:
+            pipe.exists(f"orchestrator:worker:info:{wid}")
+        existence = await pipe.execute()
+        dead_ids = [worker_ids[i] for i, exists in enumerate(existence) if not exists]
+        for wid in dead_ids:
+            tasks = await self._redis.smembers(f"orchestrator:worker:tasks:{wid}")  # type: ignore
+            async with self._redis.pipeline(transaction=True) as p:
+                p.delete(f"orchestrator:worker:tasks:{wid}")
+                p.srem("orchestrator:index:workers:all", wid)
+                p.srem("orchestrator:index:workers:idle", wid)
+                for t in tasks:
+                    p.srem(f"orchestrator:index:workers:task:{t.decode() if isinstance(t, bytes) else t}", wid)
+                await p.execute()
 
     async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
-        """Adds a job to a Redis sorted set.
-        The score is the timeout time.
-        """
         await self._redis.zadd("orchestrator:watched_jobs", {job_id: timeout_at})
 
     async def remove_job_from_watch(self, job_id: str) -> None:
-        """Removes a job from the sorted set for tracking."""
         await self._redis.zrem("orchestrator:watched_jobs", job_id)
 
-    async def get_timed_out_jobs(self) -> list[str]:
-        """Finds and removes overdue jobs from the sorted set."""
+    async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
         now = get_running_loop().time()
-        #
-
-
-
-
-
-
-
-
-
-
+        # Lua script to atomically fetch and remove timed out jobs
+        LUA_POP_TIMEOUTS = """
+        local now = ARGV[1]
+        local limit = ARGV[2]
+        local ids = redis.call('ZRANGEBYSCORE', KEYS[1], 0, now, 'LIMIT', 0, limit)
+        if #ids > 0 then
+            redis.call('ZREM', KEYS[1], unpack(ids))
+        end
+        return ids
+        """
+        try:
+            sha = await self._redis.script_load(LUA_POP_TIMEOUTS)
+            ids = await self._redis.evalsha(sha, 1, "orchestrator:watched_jobs", now, limit)
+        except NoScriptError:
+            ids = await self._redis.eval(LUA_POP_TIMEOUTS, 1, "orchestrator:watched_jobs", now, limit)
+        except ResponseError as e:
+            # Fallback for Redis versions that don't support script_load/evalsha or other errors
+            if "unknown command" in str(e).lower():
+                logger.warning("Redis does not support LUA scripts. Falling back to non-atomic get_timed_out_jobs.")
+                ids = await self._redis.zrangebyscore("orchestrator:watched_jobs", 0, now, start=0, num=limit)
+                if ids:
+                    await self._redis.zrem("orchestrator:watched_jobs", *ids)  # type: ignore
+            else:
+                raise e
+
+        if ids:
+            return [i.decode("utf-8") if isinstance(i, bytes) else i for i in ids]
         return []
 
     async def enqueue_job(self, job_id: str) -> None:
-        """Adds a job to the Redis stream."""
         await self._redis.xadd(self._stream_key, {"job_id": job_id})
 
-    async def dequeue_job(self) -> tuple[str, str] | None:
-        """Retrieves a job from the Redis stream using consumer groups.
-        Implements a recovery strategy: checks for pending messages first.
-        """
+    async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
        if not self._group_created:
            try:
                await self._redis.xgroup_create(self._stream_key, self._group_name, id="0", mkstream=True)
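With those indexes in place, `find_workers_for_task` reduces matching to a single SINTER of the task index with the idle index, and each worker drains its own sorted-set queue with blocking BZPOPMAX, highest score first. A sketch wiring the two together; the worker ID, task type, and payload are invented, with a plain string standing in for the msgpack task blob:

```python
# Sketch: SINTER-based matching plus a BZPOPMAX consumer, using the index keys
# from the diff.
import asyncio
from redis.asyncio import Redis

async def main() -> None:
    r = Redis()
    await r.sadd("orchestrator:index:workers:idle", "w1")
    await r.sadd("orchestrator:index:workers:task:transcode", "w1")

    # Matching: idle workers that also support the task type.
    candidates = await r.sinter(
        "orchestrator:index:workers:task:transcode",
        "orchestrator:index:workers:idle",
    )
    wid = candidates.pop().decode("utf-8")

    # Enqueue with a priority score, then pop the highest-priority entry.
    await r.zadd(f"orchestrator:task_queue:{wid}", {"job-42": 5.0})
    print(await r.bzpopmax([f"orchestrator:task_queue:{wid}"], timeout=1))
    await r.aclose()

asyncio.run(main())
```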
@@ -287,79 +331,39 @@ class RedisStorage(StorageBackend):
                 if "BUSYGROUP" not in str(e):
                     raise e
             self._group_created = True
-
         try:
-            try:
-                autoclaim_result = await self._redis.xautoclaim(
-                    self._stream_key,
-                    self._group_name,
-                    self._consumer_name,
-                    min_idle_time=self._min_idle_time_ms,
-                    start_id="0-0",
-                    count=1,
-                )
-                if autoclaim_result and autoclaim_result[1]:
-                    messages = autoclaim_result[1]
-                    message_id, data = messages[0]
-                    if data:
-                        job_id = data[b"job_id"].decode("utf-8")
-                        logger.info(f"Reclaimed pending message {message_id} for consumer {self._consumer_name}")
-                        return job_id, message_id.decode("utf-8")
-            except Exception as e:
-                if "unknown command" in str(e).lower() or isinstance(e, ResponseError):
-                    pending_result = await self._redis.xreadgroup(
-                        self._group_name,
-                        self._consumer_name,
-                        {self._stream_key: "0"},
-                        count=1,
-                    )
-                    if pending_result:
-                        stream_name, messages = pending_result[0]
-                        if messages:
-                            message_id, data = messages[0]
-                            job_id = data[b"job_id"].decode("utf-8")
-                            return job_id, message_id.decode("utf-8")
-                else:
-                    raise e
-
-            result = await self._redis.xreadgroup(
+            claim = await self._redis.xautoclaim(
+                self._stream_key,
                 self._group_name,
                 self._consumer_name,
-                {self._stream_key: ">"},
+                min_idle_time=self._min_idle_time_ms,
+                start_id="0-0",
                 count=1,
             )
-            if result:
-
-
-
-
+            if claim and claim[1]:
+                msg_id, data = claim[1][0]
+                return data[b"job_id"].decode("utf-8"), msg_id.decode("utf-8")
+            read = await self._redis.xreadgroup(
+                self._group_name, self._consumer_name, {self._stream_key: ">"}, count=1, block=block
+            )
+            if read:
+                msg_id, data = read[0][1][0]
+                return data[b"job_id"].decode("utf-8"), msg_id.decode("utf-8")
             return None
         except CancelledError:
             return None
 
     async def ack_job(self, message_id: str) -> None:
-        """Acknowledges a message in the Redis stream."""
         await self._redis.xack(self._stream_key, self._group_name, message_id)
 
     async def quarantine_job(self, job_id: str) -> None:
-        ""
-        await self._redis.lpush("orchestrator:quarantine_queue", job_id)  # type: ignore[arg-type]
+        await self._redis.lpush("orchestrator:quarantine_queue", job_id)  # type: ignore
 
     async def get_quarantined_jobs(self) -> list[str]:
-
-        jobs_bytes = await self._redis.lrange("orchestrator:quarantine_queue", 0, -1)
-        return [job.decode("utf-8") for job in jobs_bytes]
-
-    async def deregister_worker(self, worker_id: str) -> None:
-        """Deletes the worker key from Redis."""
-        key = f"orchestrator:worker:info:{worker_id}"
-        await self._redis.delete(key)
+        jobs = await self._redis.lrange("orchestrator:quarantine_queue", 0, -1)
+        return [j.decode("utf-8") for j in jobs]
 
     async def increment_key_with_ttl(self, key: str, ttl: int) -> int:
-        """Atomically increments a counter and sets a TTL on the first call,
-        using a Lua script for atomicity.
-        Returns the new value of the counter.
-        """
         async with self._redis.pipeline(transaction=True) as pipe:
             pipe.incr(key)
             pipe.expire(key, ttl)
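The rewritten `dequeue_job` first runs XAUTOCLAIM to take over messages that have sat pending past `min_idle_time`, then falls back to a normal XREADGROUP on `>` with an optional block. A sketch of that claim-then-read step with hypothetical stream, group, and consumer names; XAUTOCLAIM requires Redis 6.2 or newer:

```python
# Sketch of the claim-then-read consumer step from the hunk above.
import asyncio
from redis.asyncio import Redis
from redis.exceptions import ResponseError

async def main() -> None:
    r = Redis()
    stream, group, consumer = "jobs", "orchestrators", "c1"
    try:
        await r.xgroup_create(stream, group, id="0", mkstream=True)
    except ResponseError as e:
        if "BUSYGROUP" not in str(e):
            raise
    await r.xadd(stream, {"job_id": "42"})

    # 1) Reclaim messages another consumer left pending too long.
    _, claimed, *_ = await r.xautoclaim(stream, group, consumer,
                                        min_idle_time=60_000, start_id="0-0", count=1)
    if claimed:
        msg_id, data = claimed[0]
    else:
        # 2) Otherwise read the next undelivered message for this group.
        resp = await r.xreadgroup(group, consumer, {stream: ">"}, count=1, block=1000)
        msg_id, data = resp[0][1][0]
    print(data[b"job_id"], msg_id)
    await r.xack(stream, group, msg_id)
    await r.aclose()

asyncio.run(main())
```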
@@ -367,144 +371,103 @@ class RedisStorage(StorageBackend):
             return results[0]
 
     async def save_client_config(self, token: str, config: dict[str, Any]) -> None:
-
-
-        key = f"orchestrator:client_config:{token}"
-        str_config = {k: self._pack(v) for k, v in config.items()}
-        await self._redis.hset(key, mapping=str_config)
+        await self._redis.hset(
+            f"orchestrator:client_config:{token}", mapping={k: self._pack(v) for k, v in config.items()}
+        )
 
     async def get_client_config(self, token: str) -> dict[str, Any] | None:
-
-        key = f"orchestrator:client_config:{token}"
-        config_raw = await self._redis.hgetall(key)  # type: ignore[misc]
-        if not config_raw:
+        raw = await self._redis.hgetall(f"orchestrator:client_config:{token}")  # type: ignore
+        if not raw:
             return None
-
-        return {k.decode("utf-8"): self._unpack(v) for k, v in config_raw.items()}
+        return {k.decode("utf-8"): self._unpack(v) for k, v in raw.items()}
 
     async def initialize_client_quota(self, token: str, quota: int) -> None:
-        ""
-        key = f"orchestrator:quota:{token}"
-        await self._redis.set(key, quota)
+        await self._redis.set(f"orchestrator:quota:{token}", quota)
 
     async def check_and_decrement_quota(self, token: str) -> bool:
-        """Atomically checks and decrements the quota. Returns True if successful."""
         key = f"orchestrator:quota:{token}"
-
-
-
-
-            redis.call('DECR', KEYS[1])
-            return 1
-        else
-            return 0
-        end
-        """
-
+        LUA = (
+            "local c = redis.call('GET', KEYS[1]) "
+            "if c and tonumber(c) > 0 then redis.call('DECR', KEYS[1]) return 1 else return 0 end"
+        )
         try:
-
-
-            sha = await self._redis.script_load(LUA_SCRIPT)
-            result = await self._redis.evalsha(sha, 1, key)
+            sha = await self._redis.script_load(LUA)
+            res = await self._redis.evalsha(sha, 1, key)
         except NoScriptError:
-
-            # We can then fall back to executing the full script.
-            result = await self._redis.eval(LUA_SCRIPT, 1, key)
+            res = await self._redis.eval(LUA, 1, key)
         except ResponseError as e:
-            # This is the fallback path for `fakeredis` used in tests, which
-            # does not support `SCRIPT LOAD` or `EVALSHA`. It raises a
-            # ResponseError: "unknown command `script`".
             if "unknown command" in str(e):
-
-
-                current_val = await self._redis.get(key)
-                if current_val and int(current_val) > 0:
+                cur = await self._redis.get(key)
+                if cur and int(cur) > 0:
                     await self._redis.decr(key)
                     return True
                 return False
-            # If it's a different ResponseError, re-raise it.
             raise
-
-        return bool(result)
+        return bool(res)
 
     async def flush_all(self):
-        """Completely clears the current Redis database.
-        WARNING: This operation will delete ALL keys in the current DB.
-        Use for testing purposes only.
-        """
-        logger.warning("Flushing all data from Redis database.")
         await self._redis.flushdb()
 
     async def get_job_queue_length(self) -> int:
-        """Returns the length of the job stream."""
         return await self._redis.xlen(self._stream_key)
 
     async def get_active_worker_count(self) -> int:
-
-        count = 0
+        c = 0
         async for _ in self._redis.scan_iter("orchestrator:worker:info:*"):
-            count += 1
-        return count
+            c += 1
+        return c
 
     async def set_nx_ttl(self, key: str, value: str, ttl: int) -> bool:
-
-        Uses Redis SET command with NX (Not Exists) and EX (Expire) options.
-        """
-        # redis.set returns True if set, None if not set (when nx=True)
-        result = await self._redis.set(key, value, nx=True, ex=ttl)
-        return bool(result)
+        return bool(await self._redis.set(key, value, nx=True, ex=ttl))
 
     async def get_str(self, key: str) -> str | None:
         val = await self._redis.get(key)
-        if val is None:
-            return None
-        return val.decode("utf-8") if isinstance(val, bytes) else str(val)
+        return val.decode("utf-8") if isinstance(val, bytes) else str(val) if val is not None else None
 
     async def set_str(self, key: str, value: str, ttl: int | None = None) -> None:
         await self._redis.set(key, value, ex=ttl)
 
     async def set_worker_token(self, worker_id: str, token: str):
-        ""
-        key = f"orchestrator:worker:token:{worker_id}"
-        await self._redis.set(key, token)
+        await self._redis.set(f"orchestrator:worker:token:{worker_id}", token)
 
     async def get_worker_token(self, worker_id: str) -> str | None:
-
-        key = f"orchestrator:worker:token:{worker_id}"
-        token = await self._redis.get(key)
+        token = await self._redis.get(f"orchestrator:worker:token:{worker_id}")
         return token.decode("utf-8") if token else None
 
-    async def
-        ""
-
-
-
+    async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
+        await self._redis.set(f"orchestrator:sts:token:{token}", worker_id, ex=ttl)
+
+    async def verify_worker_access_token(self, token: str) -> str | None:
+        worker_id = await self._redis.get(f"orchestrator:sts:token:{token}")
+        return worker_id.decode("utf-8") if worker_id else None
 
     async def acquire_lock(self, key: str, holder_id: str, ttl: int) -> bool:
-
-        redis_key = f"orchestrator:lock:{key}"
-        result = await self._redis.set(redis_key, holder_id, nx=True, ex=ttl)
-        return bool(result)
+        return bool(await self._redis.set(f"orchestrator:lock:{key}", holder_id, nx=True, ex=ttl))
 
     async def release_lock(self, key: str, holder_id: str) -> bool:
-        "
-        redis_key = f"orchestrator:lock:{key}"
-
-        LUA_RELEASE_SCRIPT = """
-        if redis.call("get", KEYS[1]) == ARGV[1] then
-            return redis.call("del", KEYS[1])
-        else
-            return 0
-        end
-        """
+        LUA = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end"
         try:
-            result = await self._redis.eval(LUA_RELEASE_SCRIPT, 1, redis_key, holder_id)
-            return bool(result)
+            return bool(await self._redis.eval(LUA, 1, f"orchestrator:lock:{key}", holder_id))
         except ResponseError as e:
             if "unknown command" in str(e):
-
-                if
-                    await self._redis.delete(
+                cur = await self._redis.get(f"orchestrator:lock:{key}")
+                if cur and cur.decode("utf-8") == holder_id:
+                    await self._redis.delete(f"orchestrator:lock:{key}")
                     return True
                 return False
             raise e
+
+    async def ping(self) -> bool:
+        try:
+            return await self._redis.ping()
+        except Exception:
+            return False
+
+    async def reindex_workers(self) -> None:
+        """Scan existing worker keys and rebuild indexes."""
+        async for key in self._redis.scan_iter("orchestrator:worker:info:*"):  # type: ignore
+            worker_id = key.decode("utf-8").split(":")[-1]
+            raw = await self._redis.get(key)
+            if raw:
+                info = self._unpack(raw)
+                await self.register_worker(worker_id, info, int(await self._redis.ttl(key)))