avtomatika 1.0b7__py3-none-any.whl → 1.0b9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika/api/handlers.py +3 -255
- avtomatika/api/routes.py +42 -63
- avtomatika/app_keys.py +2 -0
- avtomatika/config.py +18 -0
- avtomatika/constants.py +2 -26
- avtomatika/data_types.py +4 -23
- avtomatika/dispatcher.py +9 -26
- avtomatika/engine.py +127 -6
- avtomatika/executor.py +53 -25
- avtomatika/health_checker.py +23 -5
- avtomatika/history/base.py +60 -6
- avtomatika/history/noop.py +18 -7
- avtomatika/history/postgres.py +8 -6
- avtomatika/history/sqlite.py +7 -5
- avtomatika/metrics.py +1 -1
- avtomatika/reputation.py +46 -40
- avtomatika/s3.py +379 -0
- avtomatika/security.py +56 -74
- avtomatika/services/__init__.py +0 -0
- avtomatika/services/worker_service.py +266 -0
- avtomatika/storage/base.py +55 -4
- avtomatika/storage/memory.py +56 -7
- avtomatika/storage/redis.py +214 -251
- avtomatika/utils/webhook_sender.py +44 -2
- avtomatika/watcher.py +35 -35
- avtomatika/ws_manager.py +10 -9
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/METADATA +81 -7
- avtomatika-1.0b9.dist-info/RECORD +48 -0
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/WHEEL +1 -1
- avtomatika-1.0b7.dist-info/RECORD +0 -45
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/licenses/LICENSE +0 -0
- {avtomatika-1.0b7.dist-info → avtomatika-1.0b9.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
from hashlib import sha256
|
|
2
|
+
from logging import getLogger
|
|
3
|
+
from secrets import token_urlsafe
|
|
4
|
+
from time import monotonic
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
from rxon.models import TokenResponse
|
|
8
|
+
from rxon.validators import validate_identifier
|
|
9
|
+
|
|
10
|
+
from ..app_keys import S3_SERVICE_KEY
|
|
11
|
+
from ..config import Config
|
|
12
|
+
from ..constants import (
|
|
13
|
+
ERROR_CODE_INTEGRITY_MISMATCH,
|
|
14
|
+
ERROR_CODE_INVALID_INPUT,
|
|
15
|
+
ERROR_CODE_PERMANENT,
|
|
16
|
+
ERROR_CODE_TRANSIENT,
|
|
17
|
+
JOB_STATUS_CANCELLED,
|
|
18
|
+
JOB_STATUS_FAILED,
|
|
19
|
+
JOB_STATUS_QUARANTINED,
|
|
20
|
+
JOB_STATUS_RUNNING,
|
|
21
|
+
JOB_STATUS_WAITING_FOR_PARALLEL,
|
|
22
|
+
TASK_STATUS_CANCELLED,
|
|
23
|
+
TASK_STATUS_FAILURE,
|
|
24
|
+
TASK_STATUS_SUCCESS,
|
|
25
|
+
)
|
|
26
|
+
from ..history.base import HistoryStorageBase
|
|
27
|
+
from ..storage.base import StorageBackend
|
|
28
|
+
|
|
29
|
+
logger = getLogger(__name__)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class WorkerService:
    """Application-layer service for the worker lifecycle.

    Encapsulates worker registration, long-polling task hand-out, task-result
    processing, STS token issuance and heartbeat handling so the API handlers
    stay thin.
    """

    def __init__(
        self,
        storage: StorageBackend,
        history_storage: HistoryStorageBase,
        config: Config,
        engine: Any,
    ):
        self.storage = storage
        self.history_storage = history_storage
        self.config = config
        self.engine = engine

    async def register_worker(self, worker_data: dict[str, Any]) -> None:
        """
        Registers a new worker.

        :param worker_data: Raw dictionary from request (to be validated/converted to Model later)
        :raises ValueError: If worker_id is missing or fails identifier validation.
        """
        worker_id = worker_data.get("worker_id")
        if not worker_id:
            raise ValueError("Missing required field: worker_id")

        validate_identifier(worker_id, "worker_id")

        # S3 Consistency Check: warn when the worker reports an S3 config hash
        # different from the orchestrator's.
        s3_service = self.engine.app.get(S3_SERVICE_KEY)
        if s3_service:
            orchestrator_s3_hash = s3_service.get_config_hash()
            worker_capabilities = worker_data.get("capabilities", {})
            worker_s3_hash = worker_capabilities.get("s3_config_hash")

            if orchestrator_s3_hash and worker_s3_hash and orchestrator_s3_hash != worker_s3_hash:
                logger.warning(
                    f"Worker '{worker_id}' has a different S3 configuration hash! "
                    f"Orchestrator: {orchestrator_s3_hash}, Worker: {worker_s3_hash}. "
                    "This may lead to 'split-brain' storage issues."
                )

        # Twice the heartbeat interval, so one missed heartbeat does not
        # immediately expire the registration.
        ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2
        await self.storage.register_worker(worker_id, worker_data, ttl)

        logger.info(f"Worker '{worker_id}' registered with info: {worker_data}")

        await self.history_storage.log_worker_event(
            {
                "worker_id": worker_id,
                "event_type": "registered",
                "worker_info_snapshot": worker_data,
            }
        )

    async def get_next_task(self, worker_id: str) -> Optional[dict[str, Any]]:
        """
        Retrieves the next task for a worker using long-polling configuration.
        """
        logger.debug(f"Worker {worker_id} is requesting a new task.")
        return await self.storage.dequeue_task_for_worker(worker_id, self.config.WORKER_POLL_TIMEOUT_SECONDS)

    async def process_task_result(self, result_payload: dict[str, Any], authenticated_worker_id: str) -> str:
        """
        Processes a task result submitted by a worker.
        Returns a status string constant.

        :raises PermissionError: If the payload claims a different worker_id.
        :raises ValueError: If job_id/task_id are missing.
        :raises LookupError: If the referenced job does not exist.
        """
        payload_worker_id = result_payload.get("worker_id")

        if payload_worker_id and payload_worker_id != authenticated_worker_id:
            raise PermissionError(
                f"Forbidden: Authenticated worker '{authenticated_worker_id}' "
                f"cannot submit results for another worker '{payload_worker_id}'."
            )

        job_id = result_payload.get("job_id")
        task_id = result_payload.get("task_id")
        result_data = result_payload.get("result", {})

        if not job_id or not task_id:
            raise ValueError("job_id and task_id are required")

        job_state = await self.storage.get_job_state(job_id)
        if not job_state:
            raise LookupError("Job not found")

        if job_state.get("status") == JOB_STATUS_WAITING_FOR_PARALLEL:
            # Parallel branches are watched under a composite "job:task" key.
            await self.storage.remove_job_from_watch(f"{job_id}:{task_id}")
            job_state.setdefault("aggregation_results", {})[task_id] = result_data

            branches = job_state.setdefault("active_branches", [])
            if task_id in branches:
                branches.remove(task_id)

            if not branches:
                logger.info(f"All parallel branches for job {job_id} have completed.")
                job_state["status"] = JOB_STATUS_RUNNING
                job_state["current_state"] = job_state["aggregation_target"]
                await self.storage.save_job_state(job_id, job_state)
                await self.storage.enqueue_job(job_id)
            else:
                logger.info(
                    f"Branch {task_id} for job {job_id} completed. Waiting for {len(branches)} more.",
                )
                await self.storage.save_job_state(job_id, job_state)

            return "parallel_branch_result_accepted"

        await self.storage.remove_job_from_watch(job_id)

        now = monotonic()
        dispatched_at = job_state.get("task_dispatched_at", now)
        duration_ms = int((now - dispatched_at) * 1000)

        await self.history_storage.log_job_event(
            {
                "job_id": job_id,
                "state": job_state.get("current_state"),
                "event_type": "task_finished",
                "duration_ms": duration_ms,
                "worker_id": authenticated_worker_id,
                "context_snapshot": {**job_state, "result": result_data},
            },
        )

        # A result without an explicit status is treated as a success.
        # (Computed once here; the previous revision re-read it redundantly
        # after the cancellation branch.)
        result_status = result_data.get("status", TASK_STATUS_SUCCESS)

        if result_status == TASK_STATUS_FAILURE:
            return await self._handle_task_failure(job_state, task_id, result_data)

        if result_status == TASK_STATUS_CANCELLED:
            logger.info(f"Task {task_id} for job {job_id} was cancelled by worker.")
            job_state["status"] = JOB_STATUS_CANCELLED
            await self.storage.save_job_state(job_id, job_state)

            # A scenario may define an explicit "cancelled" transition; if so,
            # the job continues from that state instead of staying cancelled.
            transitions = job_state.get("current_task_transitions", {})
            if next_state := transitions.get("cancelled"):
                job_state["current_state"] = next_state
                job_state["status"] = JOB_STATUS_RUNNING
                await self.storage.save_job_state(job_id, job_state)
                await self.storage.enqueue_job(job_id)
            return "result_accepted_cancelled"

        transitions = job_state.get("current_task_transitions", {})
        next_state = transitions.get(result_status)

        if next_state:
            logger.info(f"Job {job_id} transitioning based on worker status '{result_status}' to state '{next_state}'")

            # Merge any worker-produced data into the job's state history.
            worker_data_content = result_data.get("data")
            if worker_data_content and isinstance(worker_data_content, dict):
                if "state_history" not in job_state:
                    job_state["state_history"] = {}
                job_state["state_history"].update(worker_data_content)

            data_metadata = result_payload.get("data_metadata")
            if data_metadata:
                if "data_metadata" not in job_state:
                    job_state["data_metadata"] = {}
                job_state["data_metadata"].update(data_metadata)
                logger.debug(f"Stored data metadata for job {job_id}: {list(data_metadata.keys())}")

            job_state["current_state"] = next_state
            job_state["status"] = JOB_STATUS_RUNNING
            await self.storage.save_job_state(job_id, job_state)
            await self.storage.enqueue_job(job_id)
            return "result_accepted_success"
        else:
            logger.error(f"Job {job_id} failed. Worker returned unhandled status '{result_status}'.")
            job_state["status"] = JOB_STATUS_FAILED
            job_state["error_message"] = f"Worker returned unhandled status: {result_status}"
            await self.storage.save_job_state(job_id, job_state)
            return "result_accepted_failure"

    async def _handle_task_failure(self, job_state: dict, task_id: str, result_data: dict) -> str:
        """Routes a failed task according to its error code.

        Permanent errors quarantine the job, invalid-input and integrity
        errors fail it outright, and anything else is delegated to the
        engine's failure handler.
        """
        error_details = result_data.get("error", {})
        error_type = ERROR_CODE_TRANSIENT
        error_message = "No error details provided."

        if isinstance(error_details, dict):
            error_type = error_details.get("code", ERROR_CODE_TRANSIENT)
            error_message = error_details.get("message", "No error message provided.")
        elif isinstance(error_details, str):
            error_message = error_details

        job_id = job_state["id"]
        logger.warning(f"Task {task_id} for job {job_id} failed with error type '{error_type}'.")

        if error_type == ERROR_CODE_PERMANENT:
            job_state["status"] = JOB_STATUS_QUARANTINED
            job_state["error_message"] = f"Task failed with permanent error: {error_message}"
            await self.storage.save_job_state(job_id, job_state)
            await self.storage.quarantine_job(job_id)
        elif error_type == ERROR_CODE_INVALID_INPUT:
            job_state["status"] = JOB_STATUS_FAILED
            job_state["error_message"] = f"Task failed due to invalid input: {error_message}"
            await self.storage.save_job_state(job_id, job_state)
        elif error_type == ERROR_CODE_INTEGRITY_MISMATCH:
            job_state["status"] = JOB_STATUS_FAILED
            job_state["error_message"] = f"Task failed due to data integrity mismatch: {error_message}"
            await self.storage.save_job_state(job_id, job_state)
            logger.critical(f"Data integrity mismatch detected for job {job_id}: {error_message}")
        else:
            # Transient/unknown error codes: let the engine decide what to do.
            await self.engine.handle_task_failure(job_state, task_id, error_message)

        return "result_accepted_failure"

    async def issue_access_token(self, worker_id: str, ttl: int = 3600) -> TokenResponse:
        """Generates and stores a temporary access token.

        Only the SHA-256 hash of the token is persisted; the raw token is
        returned to the caller exactly once.

        :param worker_id: Worker the token is issued for.
        :param ttl: Token lifetime in seconds (default: one hour).
        """
        raw_token = token_urlsafe(32)
        token_hash = sha256(raw_token.encode()).hexdigest()

        await self.storage.save_worker_access_token(worker_id, token_hash, ttl)
        logger.info(f"Issued temporary access token for worker {worker_id}")

        return TokenResponse(access_token=raw_token, expires_in=ttl, worker_id=worker_id)

    async def update_worker_heartbeat(
        self, worker_id: str, update_data: Optional[dict[str, Any]]
    ) -> Optional[dict[str, Any]]:
        """Updates worker TTL and status.

        :param worker_id: The worker sending the heartbeat.
        :param update_data: Optional status payload; when absent only the TTL
            is refreshed and no history event is logged.
        :return: The updated worker info, a ttl-refresh marker, or None when
            the worker is unknown.
        """
        ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2

        if update_data:
            updated_worker = await self.storage.update_worker_status(worker_id, update_data, ttl)
            if updated_worker:
                await self.history_storage.log_worker_event(
                    {
                        "worker_id": worker_id,
                        "event_type": "status_update",
                        "worker_info_snapshot": updated_worker,
                    },
                )
            return updated_worker
        else:
            refreshed = await self.storage.refresh_worker_ttl(worker_id, ttl)
            return {"status": "ttl_refreshed"} if refreshed else None
|
avtomatika/storage/base.py
CHANGED
|
@@ -142,6 +142,37 @@ class StorageBackend(ABC):
|
|
|
142
142
|
"""
|
|
143
143
|
raise NotImplementedError
|
|
144
144
|
|
|
145
|
+
    @abstractmethod
    async def get_active_worker_ids(self) -> list[str]:
        """Returns a list of IDs for all currently active workers.

        Backends treat a worker as active while its registration TTL has
        not yet expired.

        :return: A list of worker ID strings.
        """
        raise NotImplementedError
|
|
152
|
+
|
|
153
|
+
    @abstractmethod
    async def cleanup_expired_workers(self) -> None:
        """Maintenance task to clean up internal indexes from expired worker entries.

        Implementations drop worker records whose TTL has elapsed; callers
        may invoke this periodically from a background task.
        """
        raise NotImplementedError
|
|
157
|
+
|
|
158
|
+
    @abstractmethod
    async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
        """Bulk retrieves worker info for a list of IDs.

        IDs unknown to the backend are skipped, so the result may be
        shorter than the input list.

        :param worker_ids: List of worker identifiers.
        :return: List of worker info dictionaries.
        """
        raise NotImplementedError
|
|
166
|
+
|
|
167
|
+
    @abstractmethod
    async def find_workers_for_task(self, task_type: str) -> list[str]:
        """Finds idle workers that support the given task.

        Only workers whose registration TTL has not expired are considered.

        :param task_type: The type of task to find workers for.
        :return: A list of worker IDs that are idle and support the task.
        """
        raise NotImplementedError
|
|
175
|
+
|
|
145
176
|
@abstractmethod
|
|
146
177
|
async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
|
|
147
178
|
"""Add a job to the list for timeout tracking.
|
|
@@ -152,9 +183,10 @@ class StorageBackend(ABC):
|
|
|
152
183
|
raise NotImplementedError
|
|
153
184
|
|
|
154
185
|
@abstractmethod
|
|
155
|
-
async def get_timed_out_jobs(self) -> list[str]:
|
|
186
|
+
async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
|
|
156
187
|
"""Get a list of job IDs that are overdue and remove them from the tracking list.
|
|
157
188
|
|
|
189
|
+
:param limit: Maximum number of jobs to retrieve.
|
|
158
190
|
:return: A list of overdue job IDs.
|
|
159
191
|
"""
|
|
160
192
|
raise NotImplementedError
|
|
@@ -165,9 +197,10 @@ class StorageBackend(ABC):
|
|
|
165
197
|
raise NotImplementedError
|
|
166
198
|
|
|
167
199
|
@abstractmethod
|
|
168
|
-
async def dequeue_job(self) -> tuple[str, str] | None:
|
|
200
|
+
async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
|
|
169
201
|
"""Retrieve a job ID and its message ID from the execution queue.
|
|
170
202
|
|
|
203
|
+
:param block: Milliseconds to block if no message is available. None for non-blocking.
|
|
171
204
|
:return: A tuple of (job_id, message_id) or None if the timeout has expired.
|
|
172
205
|
"""
|
|
173
206
|
raise NotImplementedError
|
|
@@ -250,7 +283,7 @@ class StorageBackend(ABC):
|
|
|
250
283
|
raise NotImplementedError
|
|
251
284
|
|
|
252
285
|
@abstractmethod
|
|
253
|
-
async def set_worker_token(self, worker_id: str, token: str):
|
|
286
|
+
async def set_worker_token(self, worker_id: str, token: str) -> None:
|
|
254
287
|
"""Saves an individual token for a specific worker."""
|
|
255
288
|
raise NotImplementedError
|
|
256
289
|
|
|
@@ -259,13 +292,23 @@ class StorageBackend(ABC):
|
|
|
259
292
|
"""Retrieves an individual token for a specific worker."""
|
|
260
293
|
raise NotImplementedError
|
|
261
294
|
|
|
295
|
+
    @abstractmethod
    async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
        """Saves a temporary access token for a worker (STS).

        :param worker_id: Worker the token belongs to.
        :param token: Token value to store (the orchestrator passes a
            SHA-256 hash, not the raw token).
        :param ttl: Token lifetime in seconds.
        """
        raise NotImplementedError
|
|
299
|
+
|
|
300
|
+
    @abstractmethod
    async def verify_worker_access_token(self, token: str) -> str | None:
        """Verifies a temporary access token and returns the associated worker_id if valid.

        :param token: Token value as previously stored via save_worker_access_token.
        :return: The worker_id, or None when the token is unknown or expired.
        """
        raise NotImplementedError
|
|
304
|
+
|
|
262
305
|
@abstractmethod
|
|
263
306
|
async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
|
|
264
307
|
"""Get complete information about a worker by its ID."""
|
|
265
308
|
raise NotImplementedError
|
|
266
309
|
|
|
267
310
|
@abstractmethod
|
|
268
|
-
async def flush_all(self):
|
|
311
|
+
async def flush_all(self) -> None:
|
|
269
312
|
"""Completely clears the storage. Used mainly for tests."""
|
|
270
313
|
raise NotImplementedError
|
|
271
314
|
|
|
@@ -312,3 +355,11 @@ class StorageBackend(ABC):
|
|
|
312
355
|
:return: True if the lock was successfully released, False otherwise.
|
|
313
356
|
"""
|
|
314
357
|
raise NotImplementedError
|
|
358
|
+
|
|
359
|
+
    @abstractmethod
    async def ping(self) -> bool:
        """Checks connection to the storage backend.

        NOTE(review): presumably consumed by the health checker — confirm.

        :return: True if storage is accessible, False otherwise.
        """
        raise NotImplementedError
|
avtomatika/storage/memory.py
CHANGED
|
@@ -154,6 +154,33 @@ class MemoryStorage(StorageBackend):
|
|
|
154
154
|
)
|
|
155
155
|
return active_workers
|
|
156
156
|
|
|
157
|
+
async def get_active_worker_ids(self) -> list[str]:
|
|
158
|
+
async with self._lock:
|
|
159
|
+
now = monotonic()
|
|
160
|
+
return [worker_id for worker_id, ttl in self._worker_ttls.items() if ttl > now]
|
|
161
|
+
|
|
162
|
+
    async def cleanup_expired_workers(self) -> None:
        """Drop expired entries from the in-memory indexes.

        Delegates to the shared TTL sweep while holding the storage lock.
        """
        async with self._lock:
            await self._clean_expired()
|
|
165
|
+
|
|
166
|
+
async def get_workers(self, worker_ids: list[str]) -> list[dict[str, Any]]:
|
|
167
|
+
async with self._lock:
|
|
168
|
+
return [self._workers[wid] for wid in worker_ids if wid in self._workers]
|
|
169
|
+
|
|
170
|
+
async def find_workers_for_task(self, task_type: str) -> list[str]:
|
|
171
|
+
"""Finds idle workers supporting the task (O(N) for memory storage)."""
|
|
172
|
+
async with self._lock:
|
|
173
|
+
now = monotonic()
|
|
174
|
+
candidates = []
|
|
175
|
+
for worker_id, info in self._workers.items():
|
|
176
|
+
if self._worker_ttls.get(worker_id, 0) <= now:
|
|
177
|
+
continue
|
|
178
|
+
if info.get("status", "idle") != "idle":
|
|
179
|
+
continue
|
|
180
|
+
if task_type in info.get("supported_tasks", []):
|
|
181
|
+
candidates.append(worker_id)
|
|
182
|
+
return candidates
|
|
183
|
+
|
|
157
184
|
async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
|
|
158
185
|
async with self._lock:
|
|
159
186
|
self._watched_jobs[job_id] = timeout_at
|
|
@@ -162,10 +189,11 @@ class MemoryStorage(StorageBackend):
|
|
|
162
189
|
async with self._lock:
|
|
163
190
|
self._watched_jobs.pop(job_id, None)
|
|
164
191
|
|
|
165
|
-
async def get_timed_out_jobs(self) -> list[str]:
|
|
192
|
+
async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
|
|
166
193
|
async with self._lock:
|
|
167
194
|
now = monotonic()
|
|
168
195
|
timed_out_ids = [job_id for job_id, timeout_at in self._watched_jobs.items() if timeout_at <= now]
|
|
196
|
+
timed_out_ids = timed_out_ids[:limit]
|
|
169
197
|
for job_id in timed_out_ids:
|
|
170
198
|
self._watched_jobs.pop(job_id, None)
|
|
171
199
|
return timed_out_ids
|
|
@@ -173,13 +201,21 @@ class MemoryStorage(StorageBackend):
|
|
|
173
201
|
async def enqueue_job(self, job_id: str) -> None:
|
|
174
202
|
await self._job_queue.put(job_id)
|
|
175
203
|
|
|
176
|
-
async def dequeue_job(self) -> tuple[str, str] | None:
|
|
177
|
-
"""Waits
|
|
178
|
-
|
|
204
|
+
async def dequeue_job(self, block: int | None = None) -> tuple[str, str] | None:
|
|
205
|
+
"""Waits for a job ID from the queue.
|
|
206
|
+
If block is None, waits indefinitely.
|
|
207
|
+
If block is int, waits for that many milliseconds.
|
|
179
208
|
"""
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
209
|
+
try:
|
|
210
|
+
if block is None:
|
|
211
|
+
job_id = await self._job_queue.get()
|
|
212
|
+
else:
|
|
213
|
+
job_id = await wait_for(self._job_queue.get(), timeout=block / 1000.0)
|
|
214
|
+
|
|
215
|
+
self._job_queue.task_done()
|
|
216
|
+
return job_id, "memory-msg-id"
|
|
217
|
+
except AsyncTimeoutError:
|
|
218
|
+
return None
|
|
183
219
|
|
|
184
220
|
async def ack_job(self, message_id: str) -> None:
|
|
185
221
|
"""No-op for MemoryStorage as it doesn't support persistent streams."""
|
|
@@ -296,6 +332,16 @@ class MemoryStorage(StorageBackend):
|
|
|
296
332
|
async with self._lock:
|
|
297
333
|
return self._worker_tokens.get(worker_id)
|
|
298
334
|
|
|
335
|
+
async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
|
|
336
|
+
async with self._lock:
|
|
337
|
+
self._generic_keys[f"sts:{token}"] = worker_id
|
|
338
|
+
self._generic_key_ttls[f"sts:{token}"] = monotonic() + ttl
|
|
339
|
+
|
|
340
|
+
async def verify_worker_access_token(self, token: str) -> str | None:
|
|
341
|
+
async with self._lock:
|
|
342
|
+
await self._clean_expired()
|
|
343
|
+
return self._generic_keys.get(f"sts:{token}")
|
|
344
|
+
|
|
299
345
|
async def set_task_cancellation_flag(self, task_id: str) -> None:
|
|
300
346
|
key = f"task_cancel:{task_id}"
|
|
301
347
|
await self.increment_key_with_ttl(key, 3600)
|
|
@@ -334,3 +380,6 @@ class MemoryStorage(StorageBackend):
|
|
|
334
380
|
del self._locks[key]
|
|
335
381
|
return True
|
|
336
382
|
return False
|
|
383
|
+
|
|
384
|
+
    async def ping(self) -> bool:
        """Always reachable: the in-memory backend lives in-process."""
        return True
|