avtomatika 1.0b7__py3-none-any.whl → 1.0b9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,6 +35,12 @@ class RedisStorage(StorageBackend):
35
35
  def _get_key(self, job_id: str) -> str:
36
36
  return f"{self._prefix}:{job_id}"
37
37
 
38
+ async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
39
+ """Gets the full info for a worker by its ID."""
40
+ key = f"orchestrator:worker:info:{worker_id}"
41
+ data = await self._redis.get(key)
42
+ return self._unpack(data) if data else None
43
+
38
44
  @staticmethod
39
45
  def _pack(data: Any) -> bytes:
40
46
  return packb(data, use_bin_type=True)
@@ -55,10 +61,8 @@ class RedisStorage(StorageBackend):
55
61
  key = f"orchestrator:task_queue:{worker_type}"
56
62
 
57
63
  pipe = self._redis.pipeline()
58
- pipe.zcard(key) # Get the number of elements
59
- # Get the top 3 highest priority bids (scores)
64
+ pipe.zcard(key)
60
65
  pipe.zrange(key, -3, -1, withscores=True, score_cast_func=float)
61
- # Get the top 3 lowest priority bids (scores)
62
66
  pipe.zrange(key, 0, 2, withscores=True, score_cast_func=float)
63
67
  results = await pipe.execute()
64
68
 
@@ -66,7 +70,6 @@ class RedisStorage(StorageBackend):
66
70
  top_bids = [score for _, score in reversed(top_bids_raw)]
67
71
  bottom_bids = [score for _, score in bottom_bids_raw]
68
72
 
69
- # Simple average calculation, can be improved for large queues
70
73
  all_scores = [s for _, s in await self._redis.zrange(key, 0, -1, withscores=True, score_cast_func=float)]
71
74
  avg_bid = sum(all_scores) / len(all_scores) if all_scores else 0
72
75
 
@@ -102,8 +105,6 @@ class RedisStorage(StorageBackend):
102
105
  await pipe.watch(key)
103
106
  current_state_raw = await pipe.get(key)
104
107
  current_state = self._unpack(current_state_raw) if current_state_raw else {}
105
-
106
- # Simple dictionary merge. For nested structures, a deep merge may be required.
107
108
  current_state.update(update_data)
108
109
 
109
110
  pipe.multi()
@@ -119,55 +120,46 @@ class RedisStorage(StorageBackend):
119
120
  worker_info: dict[str, Any],
120
121
  ttl: int,
121
122
  ) -> None:
122
- """Registers a worker in Redis."""
123
+ """Registers a worker in Redis and updates indexes."""
123
124
  worker_info.setdefault("reputation", 1.0)
124
125
  key = f"orchestrator:worker:info:{worker_id}"
125
- await self._redis.set(key, self._pack(worker_info), ex=ttl)
126
+ tasks_key = f"orchestrator:worker:tasks:{worker_id}"
126
127
 
127
- async def enqueue_task_for_worker(
128
- self,
129
- worker_id: str,
130
- task_payload: dict[str, Any],
131
- priority: float,
132
- ) -> None:
133
- """Adds a task to the priority queue (Sorted Set) for a worker."""
134
- key = f"orchestrator:task_queue:{worker_id}"
135
- await self._redis.zadd(key, {self._pack(task_payload): priority})
128
+ async with self._redis.pipeline(transaction=True) as pipe:
129
+ pipe.set(key, self._pack(worker_info), ex=ttl)
130
+ pipe.sadd("orchestrator:index:workers:all", worker_id)
136
131
 
137
- async def dequeue_task_for_worker(
138
- self,
139
- worker_id: str,
140
- timeout: int,
141
- ) -> dict[str, Any] | None:
142
- """Retrieves the highest priority task from the queue (Sorted Set),
143
- using the blocking BZPOPMAX operation.
144
- """
145
- key = f"orchestrator:task_queue:{worker_id}"
146
- try:
147
- # BZPOPMAX returns a tuple (key, member, score)
148
- result = await self._redis.bzpopmax([key], timeout=timeout)
149
- return self._unpack(result[1]) if result else None
150
- except CancelledError:
151
- return None
152
- except ResponseError as e:
153
- # Error handling if `fakeredis` does not support BZPOPMAX
154
- if "unknown command" in str(e).lower() or "wrong number of arguments" in str(e).lower():
155
- logger.warning(
156
- "BZPOPMAX is not supported (likely running with fakeredis). "
157
- "Falling back to non-blocking ZPOPMAX for testing.",
158
- )
159
- # Non-blocking fallback for tests
160
- res = await self._redis.zpopmax(key)
161
- if res:
162
- return self._unpack(res[0][0])
163
- raise e
132
+ if worker_info.get("status", "idle") == "idle":
133
+ pipe.sadd("orchestrator:index:workers:idle", worker_id)
134
+ else:
135
+ pipe.srem("orchestrator:index:workers:idle", worker_id)
164
136
 
165
- async def refresh_worker_ttl(self, worker_id: str, ttl: int) -> bool:
166
- """Updates the TTL for a worker key using the EXPIRE command."""
137
+ supported_tasks = worker_info.get("supported_tasks", [])
138
+ if supported_tasks:
139
+ pipe.sadd(tasks_key, *supported_tasks)
140
+ for task in supported_tasks:
141
+ pipe.sadd(f"orchestrator:index:workers:task:{task}", worker_id)
142
+
143
+ await pipe.execute()
144
+
145
+ async def deregister_worker(self, worker_id: str) -> None:
146
+ """Deletes the worker key and removes it from all indexes."""
167
147
  key = f"orchestrator:worker:info:{worker_id}"
168
- # EXPIRE returns 1 if the TTL was set, and 0 if the key does not exist.
169
- was_set = await self._redis.expire(key, ttl) # type: ignore[misc]
170
- return bool(was_set)
148
+ tasks_key = f"orchestrator:worker:tasks:{worker_id}"
149
+
150
+ tasks = await self._redis.smembers(tasks_key) # type: ignore
151
+
152
+ async with self._redis.pipeline(transaction=True) as pipe:
153
+ pipe.delete(key)
154
+ pipe.delete(tasks_key)
155
+ pipe.srem("orchestrator:index:workers:all", worker_id)
156
+ pipe.srem("orchestrator:index:workers:idle", worker_id)
157
+
158
+ for task in tasks:
159
+ task_str = task.decode("utf-8") if isinstance(task, bytes) else task
160
+ pipe.srem(f"orchestrator:index:workers:task:{task_str}", worker_id)
161
+
162
+ await pipe.execute()
171
163
 
172
164
  async def update_worker_status(
173
165
  self,
@@ -184,102 +176,154 @@ class RedisStorage(StorageBackend):
184
176
  return None
185
177
 
186
178
  current_state = self._unpack(current_state_raw)
187
-
188
- # Create a potential new state to compare against the current one
189
179
  new_state = current_state.copy()
190
180
  new_state.update(status_update)
191
181
 
192
182
  pipe.multi()
193
183
 
194
- # Only write to Redis if the state has actually changed.
195
184
  if new_state != current_state:
196
185
  pipe.set(key, self._pack(new_state), ex=ttl)
197
- current_state = new_state # Update the state to be returned
186
+ old_status = current_state.get("status", "idle")
187
+ new_status = new_state.get("status", "idle")
188
+
189
+ if old_status != new_status:
190
+ if new_status == "idle":
191
+ pipe.sadd("orchestrator:index:workers:idle", worker_id)
192
+ else:
193
+ pipe.srem("orchestrator:index:workers:idle", worker_id)
194
+ current_state = new_state
198
195
  else:
199
- # If nothing changed, just refresh the TTL to keep the worker alive.
200
196
  pipe.expire(key, ttl)
201
197
 
202
198
  await pipe.execute()
203
199
  return current_state
204
200
  except WatchError:
205
- # In case of a conflict, the operation can be repeated,
206
- # but for a heartbeat it is not critical, you can just skip it.
207
201
  return None
208
202
 
209
- async def update_worker_data(
210
- self,
211
- worker_id: str,
212
- update_data: dict[str, Any],
213
- ) -> dict[str, Any] | None:
203
+ async def find_workers_for_task(self, task_type: str) -> list[str]:
204
+ """Finds idle workers that support the given task using set intersection."""
205
+ task_index = f"orchestrator:index:workers:task:{task_type}"
206
+ idle_index = "orchestrator:index:workers:idle"
207
+ worker_ids = await self._redis.sinter(task_index, idle_index) # type: ignore
208
+ return [wid.decode("utf-8") if isinstance(wid, bytes) else wid for wid in worker_ids]
209
+
210
+ async def enqueue_task_for_worker(self, worker_id: str, task_payload: dict[str, Any], priority: float) -> None:
211
+ key = f"orchestrator:task_queue:{worker_id}"
212
+ await self._redis.zadd(key, {self._pack(task_payload): priority})
213
+
214
+ async def dequeue_task_for_worker(self, worker_id: str, timeout: int) -> dict[str, Any] | None:
215
+ key = f"orchestrator:task_queue:{worker_id}"
216
+ try:
217
+ result = await self._redis.bzpopmax([key], timeout=timeout)
218
+ return self._unpack(result[1]) if result else None
219
+ except CancelledError:
220
+ return None
221
+ except ResponseError as e:
222
+ if "unknown command" in str(e).lower() or "wrong number of arguments" in str(e).lower():
223
+ res = await self._redis.zpopmax(key)
224
+ if res:
225
+ return self._unpack(res[0][0])
226
+ raise e
227
+
228
+ async def refresh_worker_ttl(self, worker_id: str, ttl: int) -> bool:
229
+ was_set = await self._redis.expire(f"orchestrator:worker:info:{worker_id}", ttl)
230
+ return bool(was_set)
231
+
232
+ async def update_worker_data(self, worker_id: str, update_data: dict[str, Any]) -> dict[str, Any] | None:
214
233
  key = f"orchestrator:worker:info:{worker_id}"
215
234
  async with self._redis.pipeline(transaction=True) as pipe:
216
235
  try:
217
236
  await pipe.watch(key)
218
- current_state_raw = await pipe.get(key)
219
- if not current_state_raw:
237
+ raw = await pipe.get(key)
238
+ if not raw:
220
239
  return None
221
-
222
- current_state = self._unpack(current_state_raw)
223
- current_state.update(update_data)
224
-
240
+ data = self._unpack(raw)
241
+ data.update(update_data)
225
242
  pipe.multi()
226
- # Do not set TTL, as this is a data update, not a heartbeat
227
- pipe.set(key, self._pack(current_state))
243
+ pipe.set(key, self._pack(data))
228
244
  await pipe.execute()
229
- return current_state
245
+ return data
230
246
  except WatchError:
231
- # In case of a conflict, the operation can be repeated
232
- logger.warning(
233
- f"WatchError during worker data update for {worker_id}, retrying.",
234
- )
235
- # In this case, it is better to repeat, as updating the reputation is important
236
247
  return await self.update_worker_data(worker_id, update_data)
237
248
 
238
249
  async def get_available_workers(self) -> list[dict[str, Any]]:
239
- """Gets a list of active workers by scanning keys in Redis."""
240
- worker_keys = [key async for key in self._redis.scan_iter("orchestrator:worker:info:*")] # type: ignore[attr-defined]
241
-
250
+ worker_keys = [key async for key in self._redis.scan_iter("orchestrator:worker:info:*")] # type: ignore
242
251
  if not worker_keys:
243
252
  return []
253
+ data_list = await self._redis.mget(worker_keys)
254
+ return [self._unpack(data) for data in data_list if data]
244
255
 
245
- worker_data_list = await self._redis.mget(worker_keys)
246
- return [self._unpack(data) for data in worker_data_list if data]
256
+ async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
257
+ if not worker_ids:
258
+ return []
259
+ keys = [f"orchestrator:worker:info:{wid}" for wid in worker_ids]
260
+ data_list = await self._redis.mget(keys)
261
+ return [self._unpack(data) for data in data_list if data]
262
+
263
+ async def get_active_worker_ids(self) -> list[str]:
264
+ worker_ids = await self._redis.smembers("orchestrator:index:workers:all") # type: ignore
265
+ return [wid.decode("utf-8") if isinstance(wid, bytes) else wid for wid in worker_ids]
266
+
267
+ async def cleanup_expired_workers(self) -> None:
268
+ worker_ids = await self.get_active_worker_ids()
269
+ if not worker_ids:
270
+ return
271
+ pipe = self._redis.pipeline()
272
+ for wid in worker_ids:
273
+ pipe.exists(f"orchestrator:worker:info:{wid}")
274
+ existence = await pipe.execute()
275
+ dead_ids = [worker_ids[i] for i, exists in enumerate(existence) if not exists]
276
+ for wid in dead_ids:
277
+ tasks = await self._redis.smembers(f"orchestrator:worker:tasks:{wid}") # type: ignore
278
+ async with self._redis.pipeline(transaction=True) as p:
279
+ p.delete(f"orchestrator:worker:tasks:{wid}")
280
+ p.srem("orchestrator:index:workers:all", wid)
281
+ p.srem("orchestrator:index:workers:idle", wid)
282
+ for t in tasks:
283
+ p.srem(f"orchestrator:index:workers:task:{t.decode() if isinstance(t, bytes) else t}", wid)
284
+ await p.execute()
247
285
 
248
286
  async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
249
- """Adds a job to a Redis sorted set.
250
- The score is the timeout time.
251
- """
252
287
  await self._redis.zadd("orchestrator:watched_jobs", {job_id: timeout_at})
253
288
 
254
289
  async def remove_job_from_watch(self, job_id: str) -> None:
255
- """Removes a job from the sorted set for tracking."""
256
290
  await self._redis.zrem("orchestrator:watched_jobs", job_id)
257
291
 
258
- async def get_timed_out_jobs(self) -> list[str]:
259
- """Finds and removes overdue jobs from the sorted set."""
292
+ async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
260
293
  now = get_running_loop().time()
261
- # Find all jobs with a timeout up to the current moment
262
- timed_out_ids = await self._redis.zrangebyscore(
263
- "orchestrator:watched_jobs",
264
- 0,
265
- now,
266
- )
267
-
268
- if timed_out_ids:
269
- # Atomically remove the found IDs
270
- await self._redis.zrem("orchestrator:watched_jobs", *timed_out_ids) # type: ignore[arg-type]
271
- return [job_id.decode("utf-8") for job_id in timed_out_ids]
272
-
294
+ # Lua script to atomically fetch and remove timed out jobs
295
+ LUA_POP_TIMEOUTS = """
296
+ local now = ARGV[1]
297
+ local limit = ARGV[2]
298
+ local ids = redis.call('ZRANGEBYSCORE', KEYS[1], 0, now, 'LIMIT', 0, limit)
299
+ if #ids > 0 then
300
+ redis.call('ZREM', KEYS[1], unpack(ids))
301
+ end
302
+ return ids
303
+ """
304
+ try:
305
+ sha = await self._redis.script_load(LUA_POP_TIMEOUTS)
306
+ ids = await self._redis.evalsha(sha, 1, "orchestrator:watched_jobs", now, limit)
307
+ except NoScriptError:
308
+ ids = await self._redis.eval(LUA_POP_TIMEOUTS, 1, "orchestrator:watched_jobs", now, limit)
309
+ except ResponseError as e:
310
+ # Fallback for Redis versions that don't support script_load/evalsha or other errors
311
+ if "unknown command" in str(e).lower():
312
+ logger.warning("Redis does not support LUA scripts. Falling back to non-atomic get_timed_out_jobs.")
313
+ ids = await self._redis.zrangebyscore("orchestrator:watched_jobs", 0, now, start=0, num=limit)
314
+ if ids:
315
+ await self._redis.zrem("orchestrator:watched_jobs", *ids) # type: ignore
316
+ else:
317
+ raise e
318
+
319
+ if ids:
320
+ return [i.decode("utf-8") if isinstance(i, bytes) else i for i in ids]
273
321
  return []
274
322
 
275
323
  async def enqueue_job(self, job_id: str) -> None:
276
- """Adds a job to the Redis stream."""
277
324
  await self._redis.xadd(self._stream_key, {"job_id": job_id})
278
325
 
279
- async def dequeue_job(self) -> tuple[str, str] | None:
280
- """Retrieves a job from the Redis stream using consumer groups.
281
- Implements a recovery strategy: checks for pending messages first.
282
- """
326
+ async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
283
327
  if not self._group_created:
284
328
  try:
285
329
  await self._redis.xgroup_create(self._stream_key, self._group_name, id="0", mkstream=True)
@@ -287,79 +331,39 @@ class RedisStorage(StorageBackend):
287
331
  if "BUSYGROUP" not in str(e):
288
332
  raise e
289
333
  self._group_created = True
290
-
291
334
  try:
292
- try:
293
- autoclaim_result = await self._redis.xautoclaim(
294
- self._stream_key,
295
- self._group_name,
296
- self._consumer_name,
297
- min_idle_time=self._min_idle_time_ms,
298
- start_id="0-0",
299
- count=1,
300
- )
301
- if autoclaim_result and autoclaim_result[1]:
302
- messages = autoclaim_result[1]
303
- message_id, data = messages[0]
304
- if data:
305
- job_id = data[b"job_id"].decode("utf-8")
306
- logger.info(f"Reclaimed pending message {message_id} for consumer {self._consumer_name}")
307
- return job_id, message_id.decode("utf-8")
308
- except Exception as e:
309
- if "unknown command" in str(e).lower() or isinstance(e, ResponseError):
310
- pending_result = await self._redis.xreadgroup(
311
- self._group_name,
312
- self._consumer_name,
313
- {self._stream_key: "0"},
314
- count=1,
315
- )
316
- if pending_result:
317
- stream_name, messages = pending_result[0]
318
- if messages:
319
- message_id, data = messages[0]
320
- job_id = data[b"job_id"].decode("utf-8")
321
- return job_id, message_id.decode("utf-8")
322
- else:
323
- raise e
324
-
325
- result = await self._redis.xreadgroup(
335
+ claim = await self._redis.xautoclaim(
336
+ self._stream_key,
326
337
  self._group_name,
327
338
  self._consumer_name,
328
- {self._stream_key: ">"},
339
+ min_idle_time=self._min_idle_time_ms,
340
+ start_id="0-0",
329
341
  count=1,
330
342
  )
331
- if result:
332
- stream_name, messages = result[0]
333
- message_id, data = messages[0]
334
- job_id = data[b"job_id"].decode("utf-8")
335
- return job_id, message_id.decode("utf-8")
343
+ if claim and claim[1]:
344
+ msg_id, data = claim[1][0]
345
+ return data[b"job_id"].decode("utf-8"), msg_id.decode("utf-8")
346
+ read = await self._redis.xreadgroup(
347
+ self._group_name, self._consumer_name, {self._stream_key: ">"}, count=1, block=block
348
+ )
349
+ if read:
350
+ msg_id, data = read[0][1][0]
351
+ return data[b"job_id"].decode("utf-8"), msg_id.decode("utf-8")
336
352
  return None
337
353
  except CancelledError:
338
354
  return None
339
355
 
340
356
  async def ack_job(self, message_id: str) -> None:
341
- """Acknowledges a message in the Redis stream."""
342
357
  await self._redis.xack(self._stream_key, self._group_name, message_id)
343
358
 
344
359
  async def quarantine_job(self, job_id: str) -> None:
345
- """Moves the job ID to the 'quarantine' list in Redis."""
346
- await self._redis.lpush("orchestrator:quarantine_queue", job_id) # type: ignore[arg-type]
360
+ await self._redis.lpush("orchestrator:quarantine_queue", job_id) # type: ignore
347
361
 
348
362
  async def get_quarantined_jobs(self) -> list[str]:
349
- """Gets all job IDs from the quarantine queue."""
350
- jobs_bytes = await self._redis.lrange("orchestrator:quarantine_queue", 0, -1)
351
- return [job.decode("utf-8") for job in jobs_bytes]
352
-
353
- async def deregister_worker(self, worker_id: str) -> None:
354
- """Deletes the worker key from Redis."""
355
- key = f"orchestrator:worker:info:{worker_id}"
356
- await self._redis.delete(key)
363
+ jobs = await self._redis.lrange("orchestrator:quarantine_queue", 0, -1)
364
+ return [j.decode("utf-8") for j in jobs]
357
365
 
358
366
  async def increment_key_with_ttl(self, key: str, ttl: int) -> int:
359
- """Atomically increments a counter and sets a TTL on the first call,
360
- using a Lua script for atomicity.
361
- Returns the new value of the counter.
362
- """
363
367
  async with self._redis.pipeline(transaction=True) as pipe:
364
368
  pipe.incr(key)
365
369
  pipe.expire(key, ttl)
@@ -367,144 +371,103 @@ class RedisStorage(StorageBackend):
367
371
  return results[0]
368
372
 
369
373
  async def save_client_config(self, token: str, config: dict[str, Any]) -> None:
370
- """Saves the static client configuration as a hash."""
371
- key = f"orchestrator:client_config:{token}"
372
- # Convert all values to binary strings for storage in a Redis hash
373
- str_config = {k: self._pack(v) for k, v in config.items()}
374
- await self._redis.hset(key, mapping=str_config)
374
+ await self._redis.hset(
375
+ f"orchestrator:client_config:{token}", mapping={k: self._pack(v) for k, v in config.items()}
376
+ )
375
377
 
376
378
  async def get_client_config(self, token: str) -> dict[str, Any] | None:
377
- """Gets the static client configuration."""
378
- key = f"orchestrator:client_config:{token}"
379
- config_raw = await self._redis.hgetall(key) # type: ignore[misc]
380
- if not config_raw:
379
+ raw = await self._redis.hgetall(f"orchestrator:client_config:{token}") # type: ignore
380
+ if not raw:
381
381
  return None
382
- # Decode keys and values, parse binary
383
- return {k.decode("utf-8"): self._unpack(v) for k, v in config_raw.items()}
382
+ return {k.decode("utf-8"): self._unpack(v) for k, v in raw.items()}
384
383
 
385
384
  async def initialize_client_quota(self, token: str, quota: int) -> None:
386
- """Sets or resets the quota counter."""
387
- key = f"orchestrator:quota:{token}"
388
- await self._redis.set(key, quota)
385
+ await self._redis.set(f"orchestrator:quota:{token}", quota)
389
386
 
390
387
  async def check_and_decrement_quota(self, token: str) -> bool:
391
- """Atomically checks and decrements the quota. Returns True if successful."""
392
388
  key = f"orchestrator:quota:{token}"
393
-
394
- LUA_SCRIPT = """
395
- local current = redis.call('GET', KEYS[1])
396
- if current and tonumber(current) > 0 then
397
- redis.call('DECR', KEYS[1])
398
- return 1
399
- else
400
- return 0
401
- end
402
- """
403
-
389
+ LUA = (
390
+ "local c = redis.call('GET', KEYS[1]) "
391
+ "if c and tonumber(c) > 0 then redis.call('DECR', KEYS[1]) return 1 else return 0 end"
392
+ )
404
393
  try:
405
- # This is the most efficient path for a real Redis server.
406
- # It loads the script once and then executes it by its SHA hash.
407
- sha = await self._redis.script_load(LUA_SCRIPT)
408
- result = await self._redis.evalsha(sha, 1, key)
394
+ sha = await self._redis.script_load(LUA)
395
+ res = await self._redis.evalsha(sha, 1, key)
409
396
  except NoScriptError:
410
- # If the script is not in the cache, Redis raises NoScriptError.
411
- # We can then fall back to executing the full script.
412
- result = await self._redis.eval(LUA_SCRIPT, 1, key)
397
+ res = await self._redis.eval(LUA, 1, key)
413
398
  except ResponseError as e:
414
- # This is the fallback path for `fakeredis` used in tests, which
415
- # does not support `SCRIPT LOAD` or `EVALSHA`. It raises a
416
- # ResponseError: "unknown command `script`".
417
399
  if "unknown command" in str(e):
418
- # We resort to a non-atomic GET/DECR for testing purposes.
419
- # This is not safe for production but allows tests to pass.
420
- current_val = await self._redis.get(key)
421
- if current_val and int(current_val) > 0:
400
+ cur = await self._redis.get(key)
401
+ if cur and int(cur) > 0:
422
402
  await self._redis.decr(key)
423
403
  return True
424
404
  return False
425
- # If it's a different ResponseError, re-raise it.
426
405
  raise
427
-
428
- return bool(result)
406
+ return bool(res)
429
407
 
430
408
  async def flush_all(self):
431
- """Completely clears the current Redis database.
432
- WARNING: This operation will delete ALL keys in the current DB.
433
- Use for testing purposes only.
434
- """
435
- logger.warning("Flushing all data from Redis database.")
436
409
  await self._redis.flushdb()
437
410
 
438
411
  async def get_job_queue_length(self) -> int:
439
- """Returns the length of the job stream."""
440
412
  return await self._redis.xlen(self._stream_key)
441
413
 
442
414
  async def get_active_worker_count(self) -> int:
443
- """Returns the number of active worker keys."""
444
- count = 0
415
+ c = 0
445
416
  async for _ in self._redis.scan_iter("orchestrator:worker:info:*"):
446
- count += 1
447
- return count
417
+ c += 1
418
+ return c
448
419
 
449
420
  async def set_nx_ttl(self, key: str, value: str, ttl: int) -> bool:
450
- """
451
- Uses Redis SET command with NX (Not Exists) and EX (Expire) options.
452
- """
453
- # redis.set returns True if set, None if not set (when nx=True)
454
- result = await self._redis.set(key, value, nx=True, ex=ttl)
455
- return bool(result)
421
+ return bool(await self._redis.set(key, value, nx=True, ex=ttl))
456
422
 
457
423
  async def get_str(self, key: str) -> str | None:
458
424
  val = await self._redis.get(key)
459
- if val is None:
460
- return None
461
- return val.decode("utf-8") if isinstance(val, bytes) else str(val)
425
+ return val.decode("utf-8") if isinstance(val, bytes) else str(val) if val is not None else None
462
426
 
463
427
  async def set_str(self, key: str, value: str, ttl: int | None = None) -> None:
464
428
  await self._redis.set(key, value, ex=ttl)
465
429
 
466
430
  async def set_worker_token(self, worker_id: str, token: str):
467
- """Stores the individual token for a specific worker."""
468
- key = f"orchestrator:worker:token:{worker_id}"
469
- await self._redis.set(key, token)
431
+ await self._redis.set(f"orchestrator:worker:token:{worker_id}", token)
470
432
 
471
433
  async def get_worker_token(self, worker_id: str) -> str | None:
472
- """Retrieves the individual token for a specific worker."""
473
- key = f"orchestrator:worker:token:{worker_id}"
474
- token = await self._redis.get(key)
434
+ token = await self._redis.get(f"orchestrator:worker:token:{worker_id}")
475
435
  return token.decode("utf-8") if token else None
476
436
 
477
- async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
478
- """Gets the full info for a worker by its ID."""
479
- key = f"orchestrator:worker:info:{worker_id}"
480
- data = await self._redis.get(key)
481
- return self._unpack(data) if data else None
437
+ async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
438
+ await self._redis.set(f"orchestrator:sts:token:{token}", worker_id, ex=ttl)
439
+
440
+ async def verify_worker_access_token(self, token: str) -> str | None:
441
+ worker_id = await self._redis.get(f"orchestrator:sts:token:{token}")
442
+ return worker_id.decode("utf-8") if worker_id else None
482
443
 
483
444
  async def acquire_lock(self, key: str, holder_id: str, ttl: int) -> bool:
484
- """Attempts to acquire a lock using Redis SET NX."""
485
- redis_key = f"orchestrator:lock:{key}"
486
- result = await self._redis.set(redis_key, holder_id, nx=True, ex=ttl)
487
- return bool(result)
445
+ return bool(await self._redis.set(f"orchestrator:lock:{key}", holder_id, nx=True, ex=ttl))
488
446
 
489
447
  async def release_lock(self, key: str, holder_id: str) -> bool:
490
- """Releases the lock using a Lua script to ensure ownership."""
491
- redis_key = f"orchestrator:lock:{key}"
492
-
493
- LUA_RELEASE_SCRIPT = """
494
- if redis.call("get", KEYS[1]) == ARGV[1] then
495
- return redis.call("del", KEYS[1])
496
- else
497
- return 0
498
- end
499
- """
448
+ LUA = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end"
500
449
  try:
501
- result = await self._redis.eval(LUA_RELEASE_SCRIPT, 1, redis_key, holder_id)
502
- return bool(result)
450
+ return bool(await self._redis.eval(LUA, 1, f"orchestrator:lock:{key}", holder_id))
503
451
  except ResponseError as e:
504
452
  if "unknown command" in str(e):
505
- current_val = await self._redis.get(redis_key)
506
- if current_val and current_val.decode("utf-8") == holder_id:
507
- await self._redis.delete(redis_key)
453
+ cur = await self._redis.get(f"orchestrator:lock:{key}")
454
+ if cur and cur.decode("utf-8") == holder_id:
455
+ await self._redis.delete(f"orchestrator:lock:{key}")
508
456
  return True
509
457
  return False
510
458
  raise e
459
+
460
+ async def ping(self) -> bool:
461
+ try:
462
+ return await self._redis.ping()
463
+ except Exception:
464
+ return False
465
+
466
+ async def reindex_workers(self) -> None:
467
+ """Scan existing worker keys and rebuild indexes."""
468
+ async for key in self._redis.scan_iter("orchestrator:worker:info:*"): # type: ignore
469
+ worker_id = key.decode("utf-8").split(":")[-1]
470
+ raw = await self._redis.get(key)
471
+ if raw:
472
+ info = self._unpack(raw)
473
+ await self.register_worker(worker_id, info, int(await self._redis.ttl(key)))