avtomatika-1.0b7-py3-none-any.whl → avtomatika-1.0b9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,266 @@
+ from hashlib import sha256
+ from logging import getLogger
+ from secrets import token_urlsafe
+ from time import monotonic
+ from typing import Any, Optional
+
+ from rxon.models import TokenResponse
+ from rxon.validators import validate_identifier
+
+ from ..app_keys import S3_SERVICE_KEY
+ from ..config import Config
+ from ..constants import (
+     ERROR_CODE_INTEGRITY_MISMATCH,
+     ERROR_CODE_INVALID_INPUT,
+     ERROR_CODE_PERMANENT,
+     ERROR_CODE_TRANSIENT,
+     JOB_STATUS_CANCELLED,
+     JOB_STATUS_FAILED,
+     JOB_STATUS_QUARANTINED,
+     JOB_STATUS_RUNNING,
+     JOB_STATUS_WAITING_FOR_PARALLEL,
+     TASK_STATUS_CANCELLED,
+     TASK_STATUS_FAILURE,
+     TASK_STATUS_SUCCESS,
+ )
+ from ..history.base import HistoryStorageBase
+ from ..storage.base import StorageBackend
+
+ logger = getLogger(__name__)
+
+
+ class WorkerService:
+     def __init__(
+         self,
+         storage: StorageBackend,
+         history_storage: HistoryStorageBase,
+         config: Config,
+         engine: Any,
+     ):
+         self.storage = storage
+         self.history_storage = history_storage
+         self.config = config
+         self.engine = engine
+
+     async def register_worker(self, worker_data: dict[str, Any]) -> None:
+         """
+         Registers a new worker.
+         :param worker_data: Raw dictionary from request (to be validated/converted to Model later)
+         """
+         worker_id = worker_data.get("worker_id")
+         if not worker_id:
+             raise ValueError("Missing required field: worker_id")
+
+         validate_identifier(worker_id, "worker_id")
+
+         # S3 Consistency Check
+         s3_service = self.engine.app.get(S3_SERVICE_KEY)
+         if s3_service:
+             orchestrator_s3_hash = s3_service.get_config_hash()
+             worker_capabilities = worker_data.get("capabilities", {})
+             worker_s3_hash = worker_capabilities.get("s3_config_hash")
+
+             if orchestrator_s3_hash and worker_s3_hash and orchestrator_s3_hash != worker_s3_hash:
+                 logger.warning(
+                     f"Worker '{worker_id}' has a different S3 configuration hash! "
+                     f"Orchestrator: {orchestrator_s3_hash}, Worker: {worker_s3_hash}. "
+                     "This may lead to 'split-brain' storage issues."
+                 )
+
+         ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2
+         await self.storage.register_worker(worker_id, worker_data, ttl)
+
+         logger.info(f"Worker '{worker_id}' registered with info: {worker_data}")
+
+         await self.history_storage.log_worker_event(
+             {
+                 "worker_id": worker_id,
+                 "event_type": "registered",
+                 "worker_info_snapshot": worker_data,
+             }
+         )
+
+     async def get_next_task(self, worker_id: str) -> Optional[dict[str, Any]]:
+         """
+         Retrieves the next task for a worker using long-polling configuration.
+         """
+         logger.debug(f"Worker {worker_id} is requesting a new task.")
+         return await self.storage.dequeue_task_for_worker(worker_id, self.config.WORKER_POLL_TIMEOUT_SECONDS)
+
+     async def process_task_result(self, result_payload: dict[str, Any], authenticated_worker_id: str) -> str:
+         """
+         Processes a task result submitted by a worker.
+         Returns a status string constant.
+         """
+         payload_worker_id = result_payload.get("worker_id")
+
+         if payload_worker_id and payload_worker_id != authenticated_worker_id:
+             raise PermissionError(
+                 f"Forbidden: Authenticated worker '{authenticated_worker_id}' "
+                 f"cannot submit results for another worker '{payload_worker_id}'."
+             )
+
+         job_id = result_payload.get("job_id")
+         task_id = result_payload.get("task_id")
+         result_data = result_payload.get("result", {})
+
+         if not job_id or not task_id:
+             raise ValueError("job_id and task_id are required")
+
+         job_state = await self.storage.get_job_state(job_id)
+         if not job_state:
+             raise LookupError("Job not found")
+
+         if job_state.get("status") == JOB_STATUS_WAITING_FOR_PARALLEL:
+             await self.storage.remove_job_from_watch(f"{job_id}:{task_id}")
+             job_state.setdefault("aggregation_results", {})[task_id] = result_data
+
+             branches = job_state.setdefault("active_branches", [])
+             if task_id in branches:
+                 branches.remove(task_id)
+
+             if not branches:
+                 logger.info(f"All parallel branches for job {job_id} have completed.")
+                 job_state["status"] = JOB_STATUS_RUNNING
+                 job_state["current_state"] = job_state["aggregation_target"]
+                 await self.storage.save_job_state(job_id, job_state)
+                 await self.storage.enqueue_job(job_id)
+             else:
+                 logger.info(
+                     f"Branch {task_id} for job {job_id} completed. Waiting for {len(branches)} more.",
+                 )
+                 await self.storage.save_job_state(job_id, job_state)
+
+             return "parallel_branch_result_accepted"
+
+         await self.storage.remove_job_from_watch(job_id)
+
+         now = monotonic()
+         dispatched_at = job_state.get("task_dispatched_at", now)
+         duration_ms = int((now - dispatched_at) * 1000)
+
+         await self.history_storage.log_job_event(
+             {
+                 "job_id": job_id,
+                 "state": job_state.get("current_state"),
+                 "event_type": "task_finished",
+                 "duration_ms": duration_ms,
+                 "worker_id": authenticated_worker_id,
+                 "context_snapshot": {**job_state, "result": result_data},
+             },
+         )
+
+         result_status = result_data.get("status", TASK_STATUS_SUCCESS)  # Default to success? Constant?
+
+         if result_status == TASK_STATUS_FAILURE:
+             return await self._handle_task_failure(job_state, task_id, result_data)
+
+         if result_status == TASK_STATUS_CANCELLED:
+             logger.info(f"Task {task_id} for job {job_id} was cancelled by worker.")
+             job_state["status"] = JOB_STATUS_CANCELLED
+             await self.storage.save_job_state(job_id, job_state)
+
+             transitions = job_state.get("current_task_transitions", {})
+             if next_state := transitions.get("cancelled"):
+                 job_state["current_state"] = next_state
+                 job_state["status"] = JOB_STATUS_RUNNING
+                 await self.storage.save_job_state(job_id, job_state)
+                 await self.storage.enqueue_job(job_id)
+             return "result_accepted_cancelled"
+
+         transitions = job_state.get("current_task_transitions", {})
+         result_status = result_data.get("status", TASK_STATUS_SUCCESS)
+         next_state = transitions.get(result_status)
+
+         if next_state:
+             logger.info(f"Job {job_id} transitioning based on worker status '{result_status}' to state '{next_state}'")
+
+             worker_data_content = result_data.get("data")
+             if worker_data_content and isinstance(worker_data_content, dict):
+                 if "state_history" not in job_state:
+                     job_state["state_history"] = {}
+                 job_state["state_history"].update(worker_data_content)
+
+             data_metadata = result_payload.get("data_metadata")
+             if data_metadata:
+                 if "data_metadata" not in job_state:
+                     job_state["data_metadata"] = {}
+                 job_state["data_metadata"].update(data_metadata)
+                 logger.debug(f"Stored data metadata for job {job_id}: {list(data_metadata.keys())}")
+
+             job_state["current_state"] = next_state
+             job_state["status"] = JOB_STATUS_RUNNING
+             await self.storage.save_job_state(job_id, job_state)
+             await self.storage.enqueue_job(job_id)
+             return "result_accepted_success"
+         else:
+             logger.error(f"Job {job_id} failed. Worker returned unhandled status '{result_status}'.")
+             job_state["status"] = JOB_STATUS_FAILED
+             job_state["error_message"] = f"Worker returned unhandled status: {result_status}"
+             await self.storage.save_job_state(job_id, job_state)
+             return "result_accepted_failure"
+
+     async def _handle_task_failure(self, job_state: dict, task_id: str, result_data: dict) -> str:
+         error_details = result_data.get("error", {})
+         error_type = ERROR_CODE_TRANSIENT
+         error_message = "No error details provided."
+
+         if isinstance(error_details, dict):
+             error_type = error_details.get("code", ERROR_CODE_TRANSIENT)
+             error_message = error_details.get("message", "No error message provided.")
+         elif isinstance(error_details, str):
+             error_message = error_details
+
+         job_id = job_state["id"]
+         logger.warning(f"Task {task_id} for job {job_id} failed with error type '{error_type}'.")
+
+         if error_type == ERROR_CODE_PERMANENT:
+             job_state["status"] = JOB_STATUS_QUARANTINED
+             job_state["error_message"] = f"Task failed with permanent error: {error_message}"
+             await self.storage.save_job_state(job_id, job_state)
+             await self.storage.quarantine_job(job_id)
+         elif error_type == ERROR_CODE_INVALID_INPUT:
+             job_state["status"] = JOB_STATUS_FAILED
+             job_state["error_message"] = f"Task failed due to invalid input: {error_message}"
+             await self.storage.save_job_state(job_id, job_state)
+         elif error_type == ERROR_CODE_INTEGRITY_MISMATCH:
+             job_state["status"] = JOB_STATUS_FAILED
+             job_state["error_message"] = f"Task failed due to data integrity mismatch: {error_message}"
+             await self.storage.save_job_state(job_id, job_state)
+             logger.critical(f"Data integrity mismatch detected for job {job_id}: {error_message}")
+         else:
+             await self.engine.handle_task_failure(job_state, task_id, error_message)
+
+         return "result_accepted_failure"
+
+     async def issue_access_token(self, worker_id: str) -> TokenResponse:
+         """Generates and stores a temporary access token."""
+         raw_token = token_urlsafe(32)
+         token_hash = sha256(raw_token.encode()).hexdigest()
+         ttl = 3600
+
+         await self.storage.save_worker_access_token(worker_id, token_hash, ttl)
+         logger.info(f"Issued temporary access token for worker {worker_id}")
+
+         return TokenResponse(access_token=raw_token, expires_in=ttl, worker_id=worker_id)
+
+     async def update_worker_heartbeat(
+         self, worker_id: str, update_data: Optional[dict[str, Any]]
+     ) -> Optional[dict[str, Any]]:
+         """Updates worker TTL and status."""
+         ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2
+
+         if update_data:
+             updated_worker = await self.storage.update_worker_status(worker_id, update_data, ttl)
+             if updated_worker:
+                 await self.history_storage.log_worker_event(
+                     {
+                         "worker_id": worker_id,
+                         "event_type": "status_update",
+                         "worker_info_snapshot": updated_worker,
+                     },
+                 )
+             return updated_worker
+         else:
+             refreshed = await self.storage.refresh_worker_ttl(worker_id, ttl)
+             return {"status": "ttl_refreshed"} if refreshed else None
@@ -142,6 +142,37 @@ class StorageBackend(ABC):
          """
          raise NotImplementedError

+     @abstractmethod
+     async def get_active_worker_ids(self) -> list[str]:
+         """Returns a list of IDs for all currently active workers.
+
+         :return: A list of worker ID strings.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     async def cleanup_expired_workers(self) -> None:
+         """Maintenance task to clean up internal indexes from expired worker entries."""
+         raise NotImplementedError
+
+     @abstractmethod
+     async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
+         """Bulk retrieves worker info for a list of IDs.
+
+         :param worker_ids: List of worker identifiers.
+         :return: List of worker info dictionaries.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     async def find_workers_for_task(self, task_type: str) -> list[str]:
+         """Finds idle workers that support the given task.
+
+         :param task_type: The type of task to find workers for.
+         :return: A list of worker IDs that are idle and support the task.
+         """
+         raise NotImplementedError
+
      @abstractmethod
      async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
          """Add a job to the list for timeout tracking.
@@ -152,9 +183,10 @@ class StorageBackend(ABC):
          raise NotImplementedError

      @abstractmethod
-     async def get_timed_out_jobs(self) -> list[str]:
+     async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
          """Get a list of job IDs that are overdue and remove them from the tracking list.

+         :param limit: Maximum number of jobs to retrieve.
          :return: A list of overdue job IDs.
          """
          raise NotImplementedError
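The new `limit` parameter bounds each sweep, which matters when many jobs time out at once. An assumed watchdog loop built on it (not shown in this diff):

```python
async def sweep_timeouts(storage, handle_timeout) -> None:
    # Drain overdue jobs in bounded batches; get_timed_out_jobs also
    # removes the returned IDs from the tracking list, so the loop
    # terminates once the backlog is empty.
    while True:
        batch = await storage.get_timed_out_jobs(limit=100)
        if not batch:
            break
        for job_id in batch:
            await handle_timeout(job_id)
```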
@@ -165,9 +197,10 @@ class StorageBackend(ABC):
          raise NotImplementedError

      @abstractmethod
-     async def dequeue_job(self) -> tuple[str, str] | None:
+     async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
          """Retrieve a job ID and its message ID from the execution queue.

+         :param block: Milliseconds to block if no message is available. None for non-blocking.
          :return: A tuple of (job_id, message_id) or None if the timeout has expired.
          """
          raise NotImplementedError
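With `block`, a consumer can wait on the queue without busy-polling while still waking up to check for shutdown. Note an inconsistency worth flagging: this docstring says `None` means non-blocking, while the MemoryStorage implementation further down treats `None` as wait-indefinitely. A hypothetical consumer loop that sidesteps the ambiguity by always passing an explicit timeout (only `dequeue_job` and `ack_job` come from the interface):

```python
from asyncio import Event

async def consume(storage, process_job, stop_event: Event) -> None:
    while not stop_event.is_set():
        item = await storage.dequeue_job(block=1000)  # wait up to 1000 ms
        if item is None:
            continue  # timed out; loop back to re-check stop_event
        job_id, message_id = item
        await process_job(job_id)
        await storage.ack_job(message_id)
```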
@@ -250,7 +283,7 @@ class StorageBackend(ABC):
          raise NotImplementedError

      @abstractmethod
-     async def set_worker_token(self, worker_id: str, token: str):
+     async def set_worker_token(self, worker_id: str, token: str) -> None:
          """Saves an individual token for a specific worker."""
          raise NotImplementedError

@@ -259,13 +292,23 @@ class StorageBackend(ABC):
          """Retrieves an individual token for a specific worker."""
          raise NotImplementedError

+     @abstractmethod
+     async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
+         """Saves a temporary access token for a worker (STS)."""
+         raise NotImplementedError
+
+     @abstractmethod
+     async def verify_worker_access_token(self, token: str) -> str | None:
+         """Verifies a temporary access token and returns the associated worker_id if valid."""
+         raise NotImplementedError
+
      @abstractmethod
      async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
          """Get complete information about a worker by its ID."""
          raise NotImplementedError

      @abstractmethod
-     async def flush_all(self):
+     async def flush_all(self) -> None:
          """Completely clears the storage. Used mainly for tests."""
          raise NotImplementedError

@@ -312,3 +355,11 @@ class StorageBackend(ABC):
          :return: True if the lock was successfully released, False otherwise.
          """
          raise NotImplementedError
+
+     @abstractmethod
+     async def ping(self) -> bool:
+         """Checks connection to the storage backend.
+
+         :return: True if storage is accessible, False otherwise.
+         """
+         raise NotImplementedError
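`ping` gives deployments a cheap liveness probe. An assumed health-check handler on top of it (not part of the package):

```python
async def healthz(storage) -> tuple[int, dict[str, str]]:
    # Map storage reachability onto HTTP-style status codes.
    ok = await storage.ping()
    return (200, {"storage": "ok"}) if ok else (503, {"storage": "unavailable"})
```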
@@ -154,6 +154,33 @@ class MemoryStorage(StorageBackend):
          )
          return active_workers

+     async def get_active_worker_ids(self) -> list[str]:
+         async with self._lock:
+             now = monotonic()
+             return [worker_id for worker_id, ttl in self._worker_ttls.items() if ttl > now]
+
+     async def cleanup_expired_workers(self) -> None:
+         async with self._lock:
+             await self._clean_expired()
+
+     async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
+         async with self._lock:
+             return [self._workers[wid] for wid in worker_ids if wid in self._workers]
+
+     async def find_workers_for_task(self, task_type: str) -> list[str]:
+         """Finds idle workers supporting the task (O(N) for memory storage)."""
+         async with self._lock:
+             now = monotonic()
+             candidates = []
+             for worker_id, info in self._workers.items():
+                 if self._worker_ttls.get(worker_id, 0) <= now:
+                     continue
+                 if info.get("status", "idle") != "idle":
+                     continue
+                 if task_type in info.get("supported_tasks", []):
+                     candidates.append(worker_id)
+             return candidates
+
      async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
          async with self._lock:
              self._watched_jobs[job_id] = timeout_at
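A quick round trip against the in-memory backend (illustrative; it assumes `MemoryStorage()` takes no arguments, and the `status`/`supported_tasks` fields mirror the ones read by the scan above):

```python
import asyncio

async def demo() -> None:
    storage = MemoryStorage()
    # register_worker(worker_id, worker_data, ttl) matches the call made
    # by WorkerService in the first hunk.
    await storage.register_worker(
        "w1", {"status": "idle", "supported_tasks": ["resize"]}, 30
    )
    print(await storage.find_workers_for_task("resize"))  # expected: ['w1']

asyncio.run(demo())
```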
@@ -162,10 +189,11 @@ class MemoryStorage(StorageBackend):
          async with self._lock:
              self._watched_jobs.pop(job_id, None)

-     async def get_timed_out_jobs(self) -> list[str]:
+     async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
          async with self._lock:
              now = monotonic()
              timed_out_ids = [job_id for job_id, timeout_at in self._watched_jobs.items() if timeout_at <= now]
+             timed_out_ids = timed_out_ids[:limit]
              for job_id in timed_out_ids:
                  self._watched_jobs.pop(job_id, None)
              return timed_out_ids
@@ -173,13 +201,21 @@ class MemoryStorage(StorageBackend):
      async def enqueue_job(self, job_id: str) -> None:
          await self._job_queue.put(job_id)

-     async def dequeue_job(self) -> tuple[str, str] | None:
-         """Waits indefinitely for a job ID from the queue and returns it.
-         Returns a tuple of (job_id, message_id). In MemoryStorage, message_id is dummy.
+     async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
+         """Waits for a job ID from the queue.
+         If block is None, waits indefinitely.
+         If block is int, waits for that many milliseconds.
          """
-         job_id = await self._job_queue.get()
-         self._job_queue.task_done()
-         return job_id, "memory-msg-id"
+         try:
+             if block is None:
+                 job_id = await self._job_queue.get()
+             else:
+                 job_id = await wait_for(self._job_queue.get(), timeout=block / 1000.0)
+
+             self._job_queue.task_done()
+             return job_id, "memory-msg-id"
+         except AsyncTimeoutError:
+             return None

      async def ack_job(self, message_id: str) -> None:
          """No-op for MemoryStorage as it doesn't support persistent streams."""
@@ -296,6 +332,16 @@ class MemoryStorage(StorageBackend):
          async with self._lock:
              return self._worker_tokens.get(worker_id)

+     async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
+         async with self._lock:
+             self._generic_keys[f"sts:{token}"] = worker_id
+             self._generic_key_ttls[f"sts:{token}"] = monotonic() + ttl
+
+     async def verify_worker_access_token(self, token: str) -> str | None:
+         async with self._lock:
+             await self._clean_expired()
+             return self._generic_keys.get(f"sts:{token}")
+
      async def set_task_cancellation_flag(self, task_id: str) -> None:
          key = f"task_cancel:{task_id}"
          await self.increment_key_with_ttl(key, 3600)
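For a non-memory backend, the `sts:` key scheme above maps naturally onto a key with a server-side TTL. A speculative Redis-flavored sketch (the package's actual Redis backend, if any, is not shown in this diff; `RedisTokenStore` is hypothetical):

```python
from redis.asyncio import Redis

class RedisTokenStore:
    def __init__(self, redis: Redis) -> None:
        self.redis = redis

    async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
        # SET with EX writes the mapping and its expiry atomically, mirroring
        # the TTL that MemoryStorage tracks by hand.
        await self.redis.set(f"sts:{token}", worker_id, ex=ttl)

    async def verify_worker_access_token(self, token: str) -> str | None:
        value = await self.redis.get(f"sts:{token}")
        return value.decode() if isinstance(value, bytes) else value
```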
@@ -334,3 +380,6 @@ class MemoryStorage(StorageBackend):
                  del self._locks[key]
                  return True
          return False
+
+     async def ping(self) -> bool:
+         return True