avtomatika 1.0b8__py3-none-any.whl → 1.0b10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
avtomatika/security.py CHANGED
@@ -10,6 +10,62 @@ from .storage.base import StorageBackend
10
10
  Handler = Callable[[web.Request], Awaitable[web.Response]]
11
11
 
12
12
 
13
def _tokens_equal(provided: str, expected: object) -> bool:
    """
    Compare a provided secret with an expected one in constant time.

    Both sides are UTF-8 encoded so that non-ASCII input cannot make
    ``hmac.compare_digest`` raise ``TypeError``. A non-string or empty
    ``expected`` value (e.g. an unset GLOBAL_WORKER_TOKEN) never matches.
    """
    from hmac import compare_digest

    if not isinstance(expected, str) or not expected:
        return False
    return compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


async def verify_worker_auth(
    storage: StorageBackend,
    config: Config,
    token: str | None,
    cert_identity: str | None,
    worker_id_hint: str | None,
) -> str:
    """
    Verifies worker authentication using token or mTLS.

    Checks run in this order: client-certificate identity (mTLS), temporary
    STS access token, the worker's individual token, then the global token.

    :param storage: Backend used to look up STS and individual token hashes.
    :param config: Supplies GLOBAL_WORKER_TOKEN; an empty value disables it.
    :param token: Raw token presented by the worker, if any.
    :param cert_identity: Identity (CN) of the presented client certificate, if any.
    :param worker_id_hint: worker_id the request claims to act for, if known.
    :returns: The authenticated worker_id. If only the global token is presented
        and no worker_id hint is available, the placeholder
        "unknown_authenticated_by_global_token" is returned.
    :raises PermissionError: If no credential is supplied, a credential belongs
        to a different worker, or no credential matches.
    """
    # mTLS Check: a certificate identity takes precedence over any token.
    if cert_identity:
        if worker_id_hint and cert_identity != worker_id_hint:
            raise PermissionError(
                f"Unauthorized: Certificate CN '{cert_identity}' does not match worker_id '{worker_id_hint}'"
            )
        return cert_identity

    # Token Check
    if not token:
        raise PermissionError(f"Missing {AUTH_HEADER_WORKER} header or client certificate")

    # Stored worker/STS tokens are SHA-256 hex digests, so look them up by digest.
    hashed_provided_token = sha256(token.encode()).hexdigest()

    # STS Access Token
    token_worker_id = await storage.verify_worker_access_token(hashed_provided_token)
    if token_worker_id:
        if worker_id_hint and token_worker_id != worker_id_hint:
            raise PermissionError(
                f"Unauthorized: Access Token belongs to '{token_worker_id}', but request is for '{worker_id_hint}'"
            )
        return token_worker_id

    # Without a worker_id only the global token can be checked.
    if not worker_id_hint:
        if _tokens_equal(token, config.GLOBAL_WORKER_TOKEN):
            return "unknown_authenticated_by_global_token"

        raise PermissionError("Unauthorized: Invalid token or missing worker_id hint")

    # Individual Token for specific worker. If one exists, the global token is
    # deliberately NOT accepted as a fallback.
    expected_token_hash = await storage.get_worker_token(worker_id_hint)
    if expected_token_hash:
        if _tokens_equal(hashed_provided_token, expected_token_hash):
            return worker_id_hint
        raise PermissionError("Unauthorized: Invalid individual worker token")

    # Global Token Fallback
    if _tokens_equal(token, config.GLOBAL_WORKER_TOKEN):
        return worker_id_hint

    raise PermissionError("Unauthorized: No valid token found")
67
+
68
+
13
69
  def client_auth_middleware_factory(
14
70
  storage: StorageBackend,
15
71
  ) -> Any:
@@ -38,77 +94,3 @@ def client_auth_middleware_factory(
38
94
  return await handler(request)
39
95
 
40
96
  return middleware
41
-
42
-
43
- def worker_auth_middleware_factory(
44
- storage: StorageBackend,
45
- config: Config,
46
- ) -> Any:
47
- """
48
- Middleware factory for worker authentication.
49
- It supports both individual tokens and a global fallback token for backward compatibility.
50
- It also attaches the authenticated worker_id to the request.
51
- """
52
-
53
- @web.middleware
54
- async def middleware(request: web.Request, handler: Handler) -> web.Response:
55
- provided_token = request.headers.get(AUTH_HEADER_WORKER)
56
- if not provided_token:
57
- return web.json_response(
58
- {"error": f"Missing {AUTH_HEADER_WORKER} header"},
59
- status=401,
60
- )
61
-
62
- worker_id = request.match_info.get("worker_id")
63
- data = None
64
-
65
- # For specific endpoints, worker_id is in the body.
66
- # We need to read the body here, which can be tricky as it's a stream.
67
- # We clone the request to allow the handler to read the body again.
68
- if not worker_id and (request.path.endswith("/register") or request.path.endswith("/tasks/result")):
69
- try:
70
- cloned_request = request.clone()
71
- data = await cloned_request.json()
72
- worker_id = data.get("worker_id")
73
- # Attach the parsed data to the request so the handler doesn't need to re-parse
74
- if request.path.endswith("/register"):
75
- request["worker_registration_data"] = data
76
- except Exception:
77
- return web.json_response({"error": "Invalid JSON body"}, status=400)
78
-
79
- # If no worker_id could be determined from path or body, we can only validate against the global token.
80
- if not worker_id:
81
- if provided_token == config.GLOBAL_WORKER_TOKEN:
82
- # We don't know the worker_id, so we can't attach it.
83
- return await handler(request)
84
- else:
85
- return web.json_response(
86
- {"error": "Unauthorized: Invalid token or missing worker_id"},
87
- status=401,
88
- )
89
-
90
- # --- Individual Token Check ---
91
- expected_token_hash = await storage.get_worker_token(worker_id)
92
- if expected_token_hash:
93
- hashed_provided_token = sha256(provided_token.encode()).hexdigest()
94
- if hashed_provided_token == expected_token_hash:
95
- request["worker_id"] = worker_id # Attach authenticated worker_id
96
- return await handler(request)
97
- else:
98
- # If an individual token exists, we do not fall back to the global token.
99
- return web.json_response(
100
- {"error": "Unauthorized: Invalid individual worker token"},
101
- status=401,
102
- )
103
-
104
- # --- Global Token Fallback ---
105
- if config.GLOBAL_WORKER_TOKEN and provided_token == config.GLOBAL_WORKER_TOKEN:
106
- request["worker_id"] = worker_id # Attach authenticated worker_id
107
- return await handler(request)
108
-
109
- return web.json_response(
110
- {"error": "Unauthorized: No valid token found"},
111
- status=401,
112
- )
113
-
114
- return middleware
File without changes
@@ -0,0 +1,267 @@
1
+ from hashlib import sha256
2
+ from logging import getLogger
3
+ from secrets import token_urlsafe
4
+ from time import monotonic
5
+ from typing import Any, Optional
6
+
7
+ from rxon.models import TokenResponse
8
+ from rxon.validators import validate_identifier
9
+
10
+ from ..app_keys import S3_SERVICE_KEY
11
+ from ..config import Config
12
+ from ..constants import (
13
+ ERROR_CODE_DEPENDENCY,
14
+ ERROR_CODE_INTEGRITY_MISMATCH,
15
+ ERROR_CODE_INVALID_INPUT,
16
+ ERROR_CODE_PERMANENT,
17
+ ERROR_CODE_SECURITY,
18
+ ERROR_CODE_TRANSIENT,
19
+ JOB_STATUS_CANCELLED,
20
+ JOB_STATUS_FAILED,
21
+ JOB_STATUS_QUARANTINED,
22
+ JOB_STATUS_RUNNING,
23
+ JOB_STATUS_WAITING_FOR_PARALLEL,
24
+ TASK_STATUS_CANCELLED,
25
+ TASK_STATUS_FAILURE,
26
+ TASK_STATUS_SUCCESS,
27
+ )
28
+ from ..history.base import HistoryStorageBase
29
+ from ..storage.base import StorageBackend
30
+
31
logger = getLogger(__name__)


class WorkerService:
    """
    Worker-facing orchestrator operations: registration, task polling,
    result processing, temporary (STS) access tokens and heartbeats.
    """

    def __init__(
        self,
        storage: StorageBackend,
        history_storage: HistoryStorageBase,
        config: Config,
        engine: Any,
    ):
        self.storage = storage
        self.history_storage = history_storage
        self.config = config
        # Orchestrator engine. This service uses `engine.app` (to look up the S3
        # service) and `engine.handle_task_failure` (generic failure handling).
        self.engine = engine

    async def register_worker(self, worker_data: dict[str, Any]) -> None:
        """
        Registers a new worker.

        :param worker_data: Raw dictionary from request (to be validated/converted to Model later).
        :raises ValueError: If ``worker_id`` is missing. The id is also passed to
            ``validate_identifier``, which presumably raises on malformed ids.
        """
        worker_id = worker_data.get("worker_id")
        if not worker_id:
            raise ValueError("Missing required field: worker_id")

        validate_identifier(worker_id, "worker_id")

        # S3 Consistency Check: a mismatch is only logged, registration proceeds.
        s3_service = self.engine.app.get(S3_SERVICE_KEY)
        if s3_service:
            orchestrator_s3_hash = s3_service.get_config_hash()
            # "capabilities" may be absent, null or malformed; treat all of these
            # as "no hash reported" instead of failing the registration.
            worker_capabilities = worker_data.get("capabilities")
            worker_s3_hash = (
                worker_capabilities.get("s3_config_hash") if isinstance(worker_capabilities, dict) else None
            )

            if orchestrator_s3_hash and worker_s3_hash and orchestrator_s3_hash != worker_s3_hash:
                logger.warning(
                    f"Worker '{worker_id}' has a different S3 configuration hash! "
                    f"Orchestrator: {orchestrator_s3_hash}, Worker: {worker_s3_hash}. "
                    "This may lead to 'split-brain' storage issues."
                )

        # Same TTL that update_worker_heartbeat uses when refreshing the worker.
        ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2
        await self.storage.register_worker(worker_id, worker_data, ttl)

        logger.info(f"Worker '{worker_id}' registered with info: {worker_data}")

        await self.history_storage.log_worker_event(
            {
                "worker_id": worker_id,
                "event_type": "registered",
                "worker_info_snapshot": worker_data,
            }
        )

    async def get_next_task(self, worker_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieves the next task for a worker using long-polling configuration.

        WORKER_POLL_TIMEOUT_SECONDS is passed to the storage dequeue call
        (presumably the maximum wait); None is whatever the backend returns
        when no task is available.
        """
        logger.debug(f"Worker {worker_id} is requesting a new task.")
        return await self.storage.dequeue_task_for_worker(worker_id, self.config.WORKER_POLL_TIMEOUT_SECONDS)

    async def process_task_result(self, result_payload: dict[str, Any], authenticated_worker_id: str) -> str:
        """
        Processes a task result submitted by a worker.

        :param result_payload: Result body with ``job_id``, ``task_id`` and
            optionally ``worker_id``, ``result`` and ``data_metadata``.
        :param authenticated_worker_id: Worker identity established by auth.
        :returns: One of "parallel_branch_result_accepted",
            "result_accepted_cancelled", "result_accepted_success" or
            "result_accepted_failure".
        :raises PermissionError: If the payload names a different worker.
        :raises ValueError: If job_id/task_id are missing, or (outside parallel
            aggregation) ``result`` is neither absent/null nor an object.
        :raises LookupError: If the job does not exist.
        """
        payload_worker_id = result_payload.get("worker_id")

        if payload_worker_id and payload_worker_id != authenticated_worker_id:
            raise PermissionError(
                f"Forbidden: Authenticated worker '{authenticated_worker_id}' "
                f"cannot submit results for another worker '{payload_worker_id}'."
            )

        job_id = result_payload.get("job_id")
        task_id = result_payload.get("task_id")
        result_data = result_payload.get("result", {})

        if not job_id or not task_id:
            raise ValueError("job_id and task_id are required")

        job_state = await self.storage.get_job_state(job_id)
        if not job_state:
            raise LookupError("Job not found")

        if job_state.get("status") == JOB_STATUS_WAITING_FOR_PARALLEL:
            return await self._accept_parallel_branch_result(job_state, job_id, task_id, result_data)

        # Validate before unwatching/logging, so a malformed result cannot leave
        # the job half-processed (previously an AttributeError was raised later).
        if result_data is None:
            result_data = {}
        if not isinstance(result_data, dict):
            raise ValueError("result must be an object")

        await self.storage.remove_job_from_watch(job_id)

        now = monotonic()
        # NOTE(review): assumes task_dispatched_at was recorded with time.monotonic()
        # in this process; a value from another process/clock would skew duration.
        dispatched_at = job_state.get("task_dispatched_at", now)
        duration_ms = int((now - dispatched_at) * 1000)

        await self.history_storage.log_job_event(
            {
                "job_id": job_id,
                "state": job_state.get("current_state"),
                "event_type": "task_finished",
                "duration_ms": duration_ms,
                "worker_id": authenticated_worker_id,
                "context_snapshot": {**job_state, "result": result_data},
            },
        )

        # A result without an explicit status is treated as a success.
        result_status = result_data.get("status", TASK_STATUS_SUCCESS)

        if result_status == TASK_STATUS_FAILURE:
            return await self._handle_task_failure(job_state, task_id, result_data)

        transitions = job_state.get("current_task_transitions") or {}

        if result_status == TASK_STATUS_CANCELLED:
            return await self._apply_cancellation(job_state, job_id, task_id, transitions)

        next_state = transitions.get(result_status)
        if not next_state:
            logger.error(f"Job {job_id} failed. Worker returned unhandled status '{result_status}'.")
            job_state["status"] = JOB_STATUS_FAILED
            job_state["error_message"] = f"Worker returned unhandled status: {result_status}"
            await self.storage.save_job_state(job_id, job_state)
            return "result_accepted_failure"

        logger.info(f"Job {job_id} transitioning based on worker status '{result_status}' to state '{next_state}'")
        self._merge_worker_output(job_state, job_id, result_data, result_payload)

        job_state["current_state"] = next_state
        job_state["status"] = JOB_STATUS_RUNNING
        await self.storage.save_job_state(job_id, job_state)
        await self.storage.enqueue_job(job_id)
        return "result_accepted_success"

    async def _accept_parallel_branch_result(
        self, job_state: dict[str, Any], job_id: str, task_id: str, result_data: Any
    ) -> str:
        """Record one parallel branch result; resume the job once no active branches remain."""
        await self.storage.remove_job_from_watch(f"{job_id}:{task_id}")
        job_state.setdefault("aggregation_results", {})[task_id] = result_data

        branches = job_state.setdefault("active_branches", [])
        if task_id in branches:
            branches.remove(task_id)

        if not branches:
            logger.info(f"All parallel branches for job {job_id} have completed.")
            job_state["status"] = JOB_STATUS_RUNNING
            job_state["current_state"] = job_state["aggregation_target"]
            await self.storage.save_job_state(job_id, job_state)
            await self.storage.enqueue_job(job_id)
        else:
            logger.info(
                f"Branch {task_id} for job {job_id} completed. Waiting for {len(branches)} more.",
            )
            await self.storage.save_job_state(job_id, job_state)

        return "parallel_branch_result_accepted"

    async def _apply_cancellation(
        self, job_state: dict[str, Any], job_id: str, task_id: str, transitions: dict[str, Any]
    ) -> str:
        """Mark the job cancelled, or follow its "cancelled" transition if one is defined."""
        logger.info(f"Task {task_id} for job {job_id} was cancelled by worker.")

        next_state = transitions.get("cancelled")
        if next_state:
            job_state["current_state"] = next_state
            job_state["status"] = JOB_STATUS_RUNNING
        else:
            job_state["status"] = JOB_STATUS_CANCELLED

        # One write: readers never observe a transient CANCELLED status for a
        # job that is about to continue via its "cancelled" transition.
        await self.storage.save_job_state(job_id, job_state)
        if next_state:
            await self.storage.enqueue_job(job_id)
        return "result_accepted_cancelled"

    @staticmethod
    def _merge_worker_output(
        job_state: dict[str, Any], job_id: str, result_data: dict[str, Any], result_payload: dict[str, Any]
    ) -> None:
        """Merge the worker's ``data`` into state_history and ``data_metadata`` into the job state."""
        worker_data_content = result_data.get("data")
        if worker_data_content and isinstance(worker_data_content, dict):
            job_state.setdefault("state_history", {}).update(worker_data_content)

        data_metadata = result_payload.get("data_metadata")
        if data_metadata:
            job_state.setdefault("data_metadata", {}).update(data_metadata)
            logger.debug(f"Stored data metadata for job {job_id}: {list(data_metadata.keys())}")

    async def _handle_task_failure(self, job_state: dict, task_id: str, result_data: dict) -> str:
        """
        Apply the failure policy for a failed task result.

        PERMANENT/SECURITY/DEPENDENCY codes quarantine the job; INVALID_INPUT and
        INTEGRITY_MISMATCH fail it; any other code (TRANSIENT by default) is
        delegated to ``engine.handle_task_failure``.
        """
        error_details = result_data.get("error", {})
        error_type = ERROR_CODE_TRANSIENT
        error_message = "No error details provided."

        # "error" may be a structured {"code", "message"} object or a plain string.
        if isinstance(error_details, dict):
            error_type = error_details.get("code", ERROR_CODE_TRANSIENT)
            error_message = error_details.get("message", "No error message provided.")
        elif isinstance(error_details, str):
            error_message = error_details

        job_id = job_state["id"]
        logger.warning(f"Task {task_id} for job {job_id} failed with error type '{error_type}'.")

        if error_type in (ERROR_CODE_PERMANENT, ERROR_CODE_SECURITY, ERROR_CODE_DEPENDENCY):
            job_state["status"] = JOB_STATUS_QUARANTINED
            job_state["error_message"] = f"Task failed with permanent error ({error_type}): {error_message}"
            await self.storage.save_job_state(job_id, job_state)
            await self.storage.quarantine_job(job_id)
        elif error_type == ERROR_CODE_INVALID_INPUT:
            job_state["status"] = JOB_STATUS_FAILED
            job_state["error_message"] = f"Task failed due to invalid input: {error_message}"
            await self.storage.save_job_state(job_id, job_state)
        elif error_type == ERROR_CODE_INTEGRITY_MISMATCH:
            job_state["status"] = JOB_STATUS_FAILED
            job_state["error_message"] = f"Task failed due to data integrity mismatch: {error_message}"
            await self.storage.save_job_state(job_id, job_state)
            logger.critical(f"Data integrity mismatch detected for job {job_id}: {error_message}")
        else:
            await self.engine.handle_task_failure(job_state, task_id, error_message)
        return "result_accepted_failure"

    async def issue_access_token(self, worker_id: str, ttl: int = 3600) -> TokenResponse:
        """
        Generates and stores a temporary access token.

        Only the SHA-256 hex digest is persisted; the raw token is returned to
        the caller and cannot be recovered from storage later.

        :param worker_id: Worker the token authenticates as.
        :param ttl: Token lifetime in seconds (defaults to one hour).
        :raises ValueError: If ``ttl`` is not positive.
        """
        if ttl <= 0:
            raise ValueError("ttl must be positive")

        raw_token = token_urlsafe(32)
        token_hash = sha256(raw_token.encode()).hexdigest()

        await self.storage.save_worker_access_token(worker_id, token_hash, ttl)
        logger.info(f"Issued temporary access token for worker {worker_id}")

        return TokenResponse(access_token=raw_token, expires_in=ttl, worker_id=worker_id)

    async def update_worker_heartbeat(
        self, worker_id: str, update_data: Optional[dict[str, Any]]
    ) -> Optional[dict[str, Any]]:
        """
        Updates worker TTL and status.

        With ``update_data`` the status is updated (and logged to history when the
        backend returns the updated worker); without it only the TTL is refreshed.
        Returns the updated worker, ``{"status": "ttl_refreshed"}``, or None when
        the backend reports nothing was refreshed/updated.
        """
        ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2

        if update_data:
            updated_worker = await self.storage.update_worker_status(worker_id, update_data, ttl)
            if updated_worker:
                await self.history_storage.log_worker_event(
                    {
                        "worker_id": worker_id,
                        "event_type": "status_update",
                        "worker_info_snapshot": updated_worker,
                    },
                )
            return updated_worker

        refreshed = await self.storage.refresh_worker_ttl(worker_id, ttl)
        return {"status": "ttl_refreshed"} if refreshed else None
@@ -292,6 +292,16 @@ class StorageBackend(ABC):
292
292
  """Retrieves an individual token for a specific worker."""
293
293
  raise NotImplementedError
294
294
 
295
@abstractmethod
async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
    """
    Saves a temporary access token for a worker (STS).

    :param worker_id: Worker the token authenticates as.
    :param token: The SHA-256 hex digest of the raw token, NOT the raw token.
        Callers hash before saving, and ``verify_worker_access_token`` is
        later queried with the same digest, so implementations must not
        hash it again.
    :param ttl: Lifetime in seconds; after it elapses the entry should no
        longer verify.
    """
    raise NotImplementedError
299
+
300
+ @abstractmethod
301
+ async def verify_worker_access_token(self, token: str) -> str | None:
302
+ """Verifies a temporary access token and returns the associated worker_id if valid."""
303
+ raise NotImplementedError
304
+
295
305
  @abstractmethod
296
306
  async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
297
307
  """Get complete information about a worker by its ID."""
@@ -12,12 +12,12 @@ class MemoryStorage(StorageBackend):
12
12
  Not persistent.
13
13
  """
14
14
 
15
- def __init__(self):
15
+ def __init__(self) -> None:
16
16
  self._jobs: dict[str, dict[str, Any]] = {}
17
17
  self._workers: dict[str, dict[str, Any]] = {}
18
18
  self._worker_ttls: dict[str, float] = {}
19
- self._worker_task_queues: dict[str, PriorityQueue] = {}
20
- self._job_queue = Queue()
19
+ self._worker_task_queues: dict[str, PriorityQueue[Any]] = {}
20
+ self._job_queue: Queue[str] = Queue()
21
21
  self._quarantine_queue: list[str] = []
22
22
  self._watched_jobs: dict[str, float] = {}
23
23
  self._client_configs: dict[str, dict[str, Any]] = {}
@@ -189,10 +189,11 @@ class MemoryStorage(StorageBackend):
189
189
  async with self._lock:
190
190
  self._watched_jobs.pop(job_id, None)
191
191
 
192
- async def get_timed_out_jobs(self) -> list[str]:
192
async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
    """
    Pop up to ``limit`` watched jobs whose deadline has passed.

    Overdue jobs are returned earliest-deadline first. This matches the Redis
    backend, whose ZRANGEBYSCORE returns the lowest scores first, so when more
    than ``limit`` jobs are overdue the longest-overdue ones are handled first.
    Previously the in-memory backend returned them in dict insertion order.
    A non-positive ``limit`` returns an empty list.

    :param limit: Maximum number of job ids to return and stop watching.
    :returns: Ids of the jobs removed from the watch list.
    """
    from heapq import nsmallest

    async with self._lock:
        now = monotonic()
        overdue = (
            (timeout_at, job_id) for job_id, timeout_at in self._watched_jobs.items() if timeout_at <= now
        )
        timed_out_ids = [job_id for _, job_id in nsmallest(limit, overdue)]
        for job_id in timed_out_ids:
            self._watched_jobs.pop(job_id, None)
        return timed_out_ids
@@ -331,6 +332,16 @@ class MemoryStorage(StorageBackend):
331
332
  async with self._lock:
332
333
  return self._worker_tokens.get(worker_id)
333
334
 
335
async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
    """Bind ``worker_id`` to the STS key for ``token`` and record its expiry (monotonic clock)."""
    storage_key = f"sts:{token}"
    async with self._lock:
        expires_at = monotonic() + ttl
        self._generic_keys[storage_key] = worker_id
        self._generic_key_ttls[storage_key] = expires_at
339
+
340
+ async def verify_worker_access_token(self, token: str) -> str | None:
341
+ async with self._lock:
342
+ await self._clean_expired()
343
+ return self._generic_keys.get(f"sts:{token}")
344
+
334
345
  async def set_task_cancellation_flag(self, task_id: str) -> None:
335
346
  key = f"task_cancel:{task_id}"
336
347
  await self.increment_key_with_ttl(key, 3600)
@@ -95,7 +95,7 @@ class RedisStorage(StorageBackend):
95
95
  self,
96
96
  job_id: str,
97
97
  update_data: dict[str, Any],
98
- ) -> dict[Any, Any] | None | Any:
98
+ ) -> dict[str, Any]:
99
99
  """Atomically update the job state in Redis using a transaction."""
100
100
  key = self._get_key(job_id)
101
101
 
@@ -104,7 +104,7 @@ class RedisStorage(StorageBackend):
104
104
  try:
105
105
  await pipe.watch(key)
106
106
  current_state_raw = await pipe.get(key)
107
- current_state = self._unpack(current_state_raw) if current_state_raw else {}
107
+ current_state: dict[str, Any] = self._unpack(current_state_raw) if current_state_raw else {}
108
108
  current_state.update(update_data)
109
109
 
110
110
  pipe.multi()
@@ -147,7 +147,7 @@ class RedisStorage(StorageBackend):
147
147
  key = f"orchestrator:worker:info:{worker_id}"
148
148
  tasks_key = f"orchestrator:worker:tasks:{worker_id}"
149
149
 
150
- tasks = await self._redis.smembers(tasks_key) # type: ignore
150
+ tasks = await self._redis.smembers(tasks_key) # type: ignore[var-annotated]
151
151
 
152
152
  async with self._redis.pipeline(transaction=True) as pipe:
153
153
  pipe.delete(key)
@@ -156,7 +156,7 @@ class RedisStorage(StorageBackend):
156
156
  pipe.srem("orchestrator:index:workers:idle", worker_id)
157
157
 
158
158
  for task in tasks:
159
- task_str = task.decode("utf-8") if isinstance(task, bytes) else task
159
+ task_str = task.decode("utf-8") if isinstance(task, bytes) else str(task)
160
160
  pipe.srem(f"orchestrator:index:workers:task:{task_str}", worker_id)
161
161
 
162
162
  await pipe.execute()
@@ -204,8 +204,8 @@ class RedisStorage(StorageBackend):
204
204
  """Finds idle workers that support the given task using set intersection."""
205
205
  task_index = f"orchestrator:index:workers:task:{task_type}"
206
206
  idle_index = "orchestrator:index:workers:idle"
207
- worker_ids = await self._redis.sinter(task_index, idle_index) # type: ignore
208
- return [wid.decode("utf-8") if isinstance(wid, bytes) else wid for wid in worker_ids]
207
+ worker_ids = await self._redis.sinter(task_index, idle_index) # type: ignore[var-annotated]
208
+ return [wid.decode("utf-8") if isinstance(wid, bytes) else str(wid) for wid in worker_ids]
209
209
 
210
210
  async def enqueue_task_for_worker(self, worker_id: str, task_payload: dict[str, Any], priority: float) -> None:
211
211
  key = f"orchestrator:task_queue:{worker_id}"
@@ -274,13 +274,14 @@ class RedisStorage(StorageBackend):
274
274
  existence = await pipe.execute()
275
275
  dead_ids = [worker_ids[i] for i, exists in enumerate(existence) if not exists]
276
276
  for wid in dead_ids:
277
- tasks = await self._redis.smembers(f"orchestrator:worker:tasks:{wid}") # type: ignore
277
+ tasks = await self._redis.smembers(f"orchestrator:worker:tasks:{wid}") # type: ignore[var-annotated]
278
278
  async with self._redis.pipeline(transaction=True) as p:
279
279
  p.delete(f"orchestrator:worker:tasks:{wid}")
280
280
  p.srem("orchestrator:index:workers:all", wid)
281
281
  p.srem("orchestrator:index:workers:idle", wid)
282
282
  for t in tasks:
283
- p.srem(f"orchestrator:index:workers:task:{t.decode() if isinstance(t, bytes) else t}", wid)
283
+ t_str = t.decode() if isinstance(t, bytes) else str(t)
284
+ p.srem(f"orchestrator:index:workers:task:{t_str}", wid)
284
285
  await p.execute()
285
286
 
286
287
  async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
@@ -291,10 +292,33 @@ class RedisStorage(StorageBackend):
291
292
 
292
293
async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
    """
    Atomically pop up to ``limit`` watched jobs whose timeout has passed.

    A Lua script runs ZRANGEBYSCORE + ZREM as one step, so concurrent callers
    never receive the same job id. The script is invoked with EVALSHA using its
    locally computed SHA1 digest, which is what SCRIPT LOAD would return. This
    avoids the SCRIPT LOAD round trip the previous version paid on every call.
    EVAL is used when the server has not cached the script yet. If the server
    rejects scripting commands ("unknown command"), a non-atomic fallback is used.

    :param limit: Maximum number of job ids to pop.
    :returns: Decoded job ids, lowest timeout score first.
    """
    from hashlib import sha1

    watched_key = "orchestrator:watched_jobs"
    now = get_running_loop().time()
    # Lua script to atomically fetch and remove timed out jobs
    lua_pop_timeouts = """
local now = ARGV[1]
local limit = ARGV[2]
local ids = redis.call('ZRANGEBYSCORE', KEYS[1], 0, now, 'LIMIT', 0, limit)
if #ids > 0 then
    redis.call('ZREM', KEYS[1], unpack(ids))
end
return ids
"""
    # Redis identifies cached scripts by the SHA1 hex digest of the script body.
    script_sha = sha1(lua_pop_timeouts.encode("utf-8"), usedforsecurity=False).hexdigest()

    try:
        try:
            ids = await self._redis.evalsha(script_sha, 1, watched_key, now, limit)
        except NoScriptError:
            # Not cached on this server yet (first call, restart, SCRIPT FLUSH).
            ids = await self._redis.eval(lua_pop_timeouts, 1, watched_key, now, limit)
    except ResponseError as e:
        # Fallback for Redis versions/proxies without scripting support. The outer
        # try also covers errors raised by the EVAL retry above.
        if "unknown command" not in str(e).lower():
            raise
        logger.warning("Redis does not support LUA scripts. Falling back to non-atomic get_timed_out_jobs.")
        ids = await self._redis.zrangebyscore(watched_key, 0, now, start=0, num=limit)
        if ids:
            await self._redis.zrem(watched_key, *ids)  # type: ignore

    if not ids:
        return []
    return [i.decode("utf-8") if isinstance(i, bytes) else i for i in ids]
299
323
 
300
324
  async def enqueue_job(self, job_id: str) -> None:
@@ -411,6 +435,13 @@ class RedisStorage(StorageBackend):
411
435
  token = await self._redis.get(f"orchestrator:worker:token:{worker_id}")
412
436
  return token.decode("utf-8") if token else None
413
437
 
438
async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
    """Store the STS token -> worker_id mapping with a Redis expiry of ``ttl`` seconds."""
    redis_key = f"orchestrator:sts:token:{token}"
    await self._redis.set(redis_key, worker_id, ex=ttl)
440
+
441
+ async def verify_worker_access_token(self, token: str) -> str | None:
442
+ worker_id = await self._redis.get(f"orchestrator:sts:token:{token}")
443
+ return worker_id.decode("utf-8") if worker_id else None
444
+
414
445
async def acquire_lock(self, key: str, holder_id: str, ttl: int) -> bool:
    """Try to take the named lock for ``holder_id`` via SET NX EX; True only if it was free."""
    lock_key = f"orchestrator:lock:{key}"
    acquired = await self._redis.set(lock_key, holder_id, nx=True, ex=ttl)
    return bool(acquired)
416
447
 
avtomatika/telemetry.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from logging import getLogger
2
2
  from os import getenv
3
+ from typing import Any
3
4
 
4
5
  logger = getLogger(__name__)
5
6
 
@@ -17,28 +18,28 @@ except ImportError:
17
18
  TELEMETRY_ENABLED = False
18
19
 
19
20
class DummySpan:
    """No-op stand-in for an OpenTelemetry span, used when telemetry is unavailable."""

    def __enter__(self) -> "DummySpan":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # Returning None (falsy) lets any exception from the `with` body propagate.
        return None

    def set_attribute(self, key: str, value: Any) -> None:
        """Accept and discard a span attribute."""
        return None


class DummyTracer:
    """No-op tracer whose spans are DummySpan instances."""

    @staticmethod
    def start_as_current_span(name: str, context: Any = None) -> DummySpan:
        span = DummySpan()
        return span


class NoOpTrace:
    """Minimal replacement for the tracing entry point, handing out DummyTracer objects."""

    def get_tracer(self, name: str) -> DummyTracer:
        tracer = DummyTracer()
        return tracer


trace: Any = NoOpTrace()  # type: ignore[no-redef]
39
40
 
40
41
 
41
- def setup_telemetry(service_name: str = "avtomatika"):
42
+ def setup_telemetry(service_name: str = "avtomatika") -> Any:
42
43
  """Configures OpenTelemetry for the application if installed."""
43
44
  if not TELEMETRY_ENABLED:
44
45
  logger.info("opentelemetry-sdk not found. Telemetry is disabled.")
@@ -1,8 +1,8 @@
1
- from asyncio import CancelledError, Queue, QueueFull, create_task, sleep
1
+ from asyncio import CancelledError, Queue, QueueFull, Task, create_task, sleep
2
2
  from contextlib import suppress
3
3
  from dataclasses import asdict, dataclass
4
4
  from logging import getLogger
5
- from typing import Any
5
+ from typing import Any, Optional
6
6
 
7
7
  from aiohttp import ClientSession, ClientTimeout
8
8
 
@@ -24,7 +24,7 @@ class WebhookSender:
24
24
  self.timeout = ClientTimeout(total=10)
25
25
  self.max_retries = 3
26
26
  self._queue: Queue[tuple[str, WebhookPayload]] = Queue(maxsize=1000)
27
- self._worker_task = None
27
+ self._worker_task: Optional[Task[None]] = None
28
28
 
29
29
  def start(self) -> None:
30
30
  if not self._worker_task:
avtomatika/watcher.py CHANGED
@@ -3,6 +3,8 @@ from logging import getLogger
3
3
  from typing import TYPE_CHECKING
4
4
  from uuid import uuid4
5
5
 
6
+ from .constants import JOB_STATUS_FAILED, JOB_STATUS_WAITING_FOR_WORKER
7
+
6
8
  if TYPE_CHECKING:
7
9
  from .engine import OrchestratorEngine
8
10
 
@@ -38,8 +40,8 @@ class Watcher:
38
40
  try:
39
41
  # Get the latest version to avoid overwriting
40
42
  job_state = await self.storage.get_job_state(job_id)
41
- if job_state and job_state["status"] == "waiting_for_worker":
42
- job_state["status"] = "failed"
43
+ if job_state and job_state["status"] == JOB_STATUS_WAITING_FOR_WORKER:
44
+ job_state["status"] = JOB_STATUS_FAILED
43
45
  job_state["error_message"] = "Worker task timed out."
44
46
  await self.storage.save_job_state(job_id, job_state)
45
47