avtomatika-1.0b8-py3-none-any.whl → avtomatika-1.0b10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- avtomatika/api/handlers.py +5 -257
- avtomatika/api/routes.py +42 -63
- avtomatika/api.html +1 -1
- avtomatika/app_keys.py +1 -0
- avtomatika/blueprint.py +3 -2
- avtomatika/config.py +8 -0
- avtomatika/constants.py +75 -25
- avtomatika/data_types.py +2 -22
- avtomatika/dispatcher.py +4 -0
- avtomatika/engine.py +119 -7
- avtomatika/executor.py +19 -19
- avtomatika/logging_config.py +16 -7
- avtomatika/s3.py +96 -40
- avtomatika/scheduler_config_loader.py +5 -2
- avtomatika/security.py +56 -74
- avtomatika/services/__init__.py +0 -0
- avtomatika/services/worker_service.py +267 -0
- avtomatika/storage/base.py +10 -0
- avtomatika/storage/memory.py +15 -4
- avtomatika/storage/redis.py +42 -11
- avtomatika/telemetry.py +8 -7
- avtomatika/utils/webhook_sender.py +3 -3
- avtomatika/watcher.py +4 -2
- avtomatika/ws_manager.py +16 -8
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/METADATA +47 -15
- avtomatika-1.0b10.dist-info/RECORD +48 -0
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/WHEEL +1 -1
- avtomatika-1.0b8.dist-info/RECORD +0 -46
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/licenses/LICENSE +0 -0
- {avtomatika-1.0b8.dist-info → avtomatika-1.0b10.dist-info}/top_level.txt +0 -0
avtomatika/security.py
CHANGED

```diff
@@ -10,6 +10,62 @@ from .storage.base import StorageBackend
 
 Handler = Callable[[web.Request], Awaitable[web.Response]]
 
 
+async def verify_worker_auth(
+    storage: StorageBackend,
+    config: Config,
+    token: str | None,
+    cert_identity: str | None,
+    worker_id_hint: str | None,
+) -> str:
+    """
+    Verifies worker authentication using token or mTLS.
+    Returns authenticated worker_id.
+    Raises ValueError (400), PermissionError (401/403) on failure.
+    """
+    # mTLS Check
+    if cert_identity:
+        if worker_id_hint and cert_identity != worker_id_hint:
+            raise PermissionError(
+                f"Unauthorized: Certificate CN '{cert_identity}' does not match worker_id '{worker_id_hint}'"
+            )
+        return cert_identity
+
+    # Token Check
+    if not token:
+        raise PermissionError(f"Missing {AUTH_HEADER_WORKER} header or client certificate")
+
+    hashed_provided_token = sha256(token.encode()).hexdigest()
+
+    # STS Access Token
+    token_worker_id = await storage.verify_worker_access_token(hashed_provided_token)
+    if token_worker_id:
+        if worker_id_hint and token_worker_id != worker_id_hint:
+            raise PermissionError(
+                f"Unauthorized: Access Token belongs to '{token_worker_id}', but request is for '{worker_id_hint}'"
+            )
+        return token_worker_id
+
+    # Individual/Global Token
+    if not worker_id_hint:
+        if config.GLOBAL_WORKER_TOKEN and token == config.GLOBAL_WORKER_TOKEN:
+            return "unknown_authenticated_by_global_token"
+
+        raise PermissionError("Unauthorized: Invalid token or missing worker_id hint")
+
+    # Individual Token for specific worker
+    expected_token_hash = await storage.get_worker_token(worker_id_hint)
+    if expected_token_hash:
+        if hashed_provided_token == expected_token_hash:
+            return worker_id_hint
+        raise PermissionError("Unauthorized: Invalid individual worker token")
+
+    # Global Token Fallback
+    if config.GLOBAL_WORKER_TOKEN and token == config.GLOBAL_WORKER_TOKEN:
+        return worker_id_hint
+
+    raise PermissionError("Unauthorized: No valid token found")
+
+
 def client_auth_middleware_factory(
     storage: StorageBackend,
 ) -> Any:
@@ -38,77 +94,3 @@ def client_auth_middleware_factory(
         return await handler(request)
 
     return middleware
-
-
-def worker_auth_middleware_factory(
-    storage: StorageBackend,
-    config: Config,
-) -> Any:
-    """
-    Middleware factory for worker authentication.
-    It supports both individual tokens and a global fallback token for backward compatibility.
-    It also attaches the authenticated worker_id to the request.
-    """
-
-    @web.middleware
-    async def middleware(request: web.Request, handler: Handler) -> web.Response:
-        provided_token = request.headers.get(AUTH_HEADER_WORKER)
-        if not provided_token:
-            return web.json_response(
-                {"error": f"Missing {AUTH_HEADER_WORKER} header"},
-                status=401,
-            )
-
-        worker_id = request.match_info.get("worker_id")
-        data = None
-
-        # For specific endpoints, worker_id is in the body.
-        # We need to read the body here, which can be tricky as it's a stream.
-        # We clone the request to allow the handler to read the body again.
-        if not worker_id and (request.path.endswith("/register") or request.path.endswith("/tasks/result")):
-            try:
-                cloned_request = request.clone()
-                data = await cloned_request.json()
-                worker_id = data.get("worker_id")
-                # Attach the parsed data to the request so the handler doesn't need to re-parse
-                if request.path.endswith("/register"):
-                    request["worker_registration_data"] = data
-            except Exception:
-                return web.json_response({"error": "Invalid JSON body"}, status=400)
-
-        # If no worker_id could be determined from path or body, we can only validate against the global token.
-        if not worker_id:
-            if provided_token == config.GLOBAL_WORKER_TOKEN:
-                # We don't know the worker_id, so we can't attach it.
-                return await handler(request)
-            else:
-                return web.json_response(
-                    {"error": "Unauthorized: Invalid token or missing worker_id"},
-                    status=401,
-                )
-
-        # --- Individual Token Check ---
-        expected_token_hash = await storage.get_worker_token(worker_id)
-        if expected_token_hash:
-            hashed_provided_token = sha256(provided_token.encode()).hexdigest()
-            if hashed_provided_token == expected_token_hash:
-                request["worker_id"] = worker_id  # Attach authenticated worker_id
-                return await handler(request)
-            else:
-                # If an individual token exists, we do not fall back to the global token.
-                return web.json_response(
-                    {"error": "Unauthorized: Invalid individual worker token"},
-                    status=401,
-                )
-
-        # --- Global Token Fallback ---
-        if config.GLOBAL_WORKER_TOKEN and provided_token == config.GLOBAL_WORKER_TOKEN:
-            request["worker_id"] = worker_id  # Attach authenticated worker_id
-            return await handler(request)
-
-        return web.json_response(
-            {"error": "Unauthorized: No valid token found"},
-            status=401,
-        )
-
-    return middleware
```
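The refactor inverts error handling: where `worker_auth_middleware_factory` answered with `web.json_response` errors itself, `verify_worker_auth` raises and leaves the HTTP mapping to its caller (per its docstring, `ValueError` maps to 400 and `PermissionError` to 401/403). A minimal sketch of how a caller might consume it; the handler name, app-key wiring, and `cert_identity` extraction are illustrative assumptions, not the package's actual route code:

```python
from aiohttp import web

# Assumed import paths: verify_worker_auth is defined in avtomatika/security.py
# as shown above; AUTH_HEADER_WORKER is presumed to live in the constants module.
from avtomatika.security import verify_worker_auth
from avtomatika.constants import AUTH_HEADER_WORKER

async def heartbeat_handler(request: web.Request) -> web.Response:
    storage = request.app["storage"]  # hypothetical app wiring
    config = request.app["config"]
    try:
        worker_id = await verify_worker_auth(
            storage,
            config,
            token=request.headers.get(AUTH_HEADER_WORKER),
            cert_identity=request.get("cert_identity"),  # set by a TLS layer, assumed
            worker_id_hint=request.match_info.get("worker_id"),
        )
    except ValueError as e:
        # Docstring contract: bad input -> 400
        return web.json_response({"error": str(e)}, status=400)
    except PermissionError as e:
        # Docstring contract: auth failure -> 401/403
        return web.json_response({"error": str(e)}, status=401)
    return web.json_response({"worker_id": worker_id})
```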
avtomatika/services/__init__.py
ADDED
File without changes (new empty file)

avtomatika/services/worker_service.py
ADDED

```diff
@@ -0,0 +1,267 @@
+from hashlib import sha256
+from logging import getLogger
+from secrets import token_urlsafe
+from time import monotonic
+from typing import Any, Optional
+
+from rxon.models import TokenResponse
+from rxon.validators import validate_identifier
+
+from ..app_keys import S3_SERVICE_KEY
+from ..config import Config
+from ..constants import (
+    ERROR_CODE_DEPENDENCY,
+    ERROR_CODE_INTEGRITY_MISMATCH,
+    ERROR_CODE_INVALID_INPUT,
+    ERROR_CODE_PERMANENT,
+    ERROR_CODE_SECURITY,
+    ERROR_CODE_TRANSIENT,
+    JOB_STATUS_CANCELLED,
+    JOB_STATUS_FAILED,
+    JOB_STATUS_QUARANTINED,
+    JOB_STATUS_RUNNING,
+    JOB_STATUS_WAITING_FOR_PARALLEL,
+    TASK_STATUS_CANCELLED,
+    TASK_STATUS_FAILURE,
+    TASK_STATUS_SUCCESS,
+)
+from ..history.base import HistoryStorageBase
+from ..storage.base import StorageBackend
+
+logger = getLogger(__name__)
+
+
+class WorkerService:
+    def __init__(
+        self,
+        storage: StorageBackend,
+        history_storage: HistoryStorageBase,
+        config: Config,
+        engine: Any,
+    ):
+        self.storage = storage
+        self.history_storage = history_storage
+        self.config = config
+        self.engine = engine
+
+    async def register_worker(self, worker_data: dict[str, Any]) -> None:
+        """
+        Registers a new worker.
+        :param worker_data: Raw dictionary from request (to be validated/converted to Model later)
+        """
+        worker_id = worker_data.get("worker_id")
+        if not worker_id:
+            raise ValueError("Missing required field: worker_id")
+
+        validate_identifier(worker_id, "worker_id")
+
+        # S3 Consistency Check
+        s3_service = self.engine.app.get(S3_SERVICE_KEY)
+        if s3_service:
+            orchestrator_s3_hash = s3_service.get_config_hash()
+            worker_capabilities = worker_data.get("capabilities", {})
+            worker_s3_hash = worker_capabilities.get("s3_config_hash")
+
+            if orchestrator_s3_hash and worker_s3_hash and orchestrator_s3_hash != worker_s3_hash:
+                logger.warning(
+                    f"Worker '{worker_id}' has a different S3 configuration hash! "
+                    f"Orchestrator: {orchestrator_s3_hash}, Worker: {worker_s3_hash}. "
+                    "This may lead to 'split-brain' storage issues."
+                )
+
+        ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2
+        await self.storage.register_worker(worker_id, worker_data, ttl)
+
+        logger.info(f"Worker '{worker_id}' registered with info: {worker_data}")
+
+        await self.history_storage.log_worker_event(
+            {
+                "worker_id": worker_id,
+                "event_type": "registered",
+                "worker_info_snapshot": worker_data,
+            }
+        )
+
+    async def get_next_task(self, worker_id: str) -> Optional[dict[str, Any]]:
+        """
+        Retrieves the next task for a worker using long-polling configuration.
+        """
+        logger.debug(f"Worker {worker_id} is requesting a new task.")
+        return await self.storage.dequeue_task_for_worker(worker_id, self.config.WORKER_POLL_TIMEOUT_SECONDS)
+
+    async def process_task_result(self, result_payload: dict[str, Any], authenticated_worker_id: str) -> str:
+        """
+        Processes a task result submitted by a worker.
+        Returns a status string constant.
+        """
+        payload_worker_id = result_payload.get("worker_id")
+
+        if payload_worker_id and payload_worker_id != authenticated_worker_id:
+            raise PermissionError(
+                f"Forbidden: Authenticated worker '{authenticated_worker_id}' "
+                f"cannot submit results for another worker '{payload_worker_id}'."
+            )
+
+        job_id = result_payload.get("job_id")
+        task_id = result_payload.get("task_id")
+        result_data = result_payload.get("result", {})
+
+        if not job_id or not task_id:
+            raise ValueError("job_id and task_id are required")
+
+        job_state = await self.storage.get_job_state(job_id)
+        if not job_state:
+            raise LookupError("Job not found")
+
+        if job_state.get("status") == JOB_STATUS_WAITING_FOR_PARALLEL:
+            await self.storage.remove_job_from_watch(f"{job_id}:{task_id}")
+            job_state.setdefault("aggregation_results", {})[task_id] = result_data
+
+            branches = job_state.setdefault("active_branches", [])
+            if task_id in branches:
+                branches.remove(task_id)
+
+            if not branches:
+                logger.info(f"All parallel branches for job {job_id} have completed.")
+                job_state["status"] = JOB_STATUS_RUNNING
+                job_state["current_state"] = job_state["aggregation_target"]
+                await self.storage.save_job_state(job_id, job_state)
+                await self.storage.enqueue_job(job_id)
+            else:
+                logger.info(
+                    f"Branch {task_id} for job {job_id} completed. Waiting for {len(branches)} more.",
+                )
+                await self.storage.save_job_state(job_id, job_state)
+
+            return "parallel_branch_result_accepted"
+
+        await self.storage.remove_job_from_watch(job_id)
+
+        now = monotonic()
+        dispatched_at = job_state.get("task_dispatched_at", now)
+        duration_ms = int((now - dispatched_at) * 1000)
+
+        await self.history_storage.log_job_event(
+            {
+                "job_id": job_id,
+                "state": job_state.get("current_state"),
+                "event_type": "task_finished",
+                "duration_ms": duration_ms,
+                "worker_id": authenticated_worker_id,
+                "context_snapshot": {**job_state, "result": result_data},
+            },
+        )
+
+        result_status = result_data.get("status", TASK_STATUS_SUCCESS)  # Default to success? Constant?
+
+        if result_status == TASK_STATUS_FAILURE:
+            return await self._handle_task_failure(job_state, task_id, result_data)
+
+        if result_status == TASK_STATUS_CANCELLED:
+            logger.info(f"Task {task_id} for job {job_id} was cancelled by worker.")
+            job_state["status"] = JOB_STATUS_CANCELLED
+            await self.storage.save_job_state(job_id, job_state)
+
+            transitions = job_state.get("current_task_transitions", {})
+            if next_state := transitions.get("cancelled"):
+                job_state["current_state"] = next_state
+                job_state["status"] = JOB_STATUS_RUNNING
+                await self.storage.save_job_state(job_id, job_state)
+                await self.storage.enqueue_job(job_id)
+            return "result_accepted_cancelled"
+
+        transitions = job_state.get("current_task_transitions", {})
+        result_status = result_data.get("status", TASK_STATUS_SUCCESS)
+        next_state = transitions.get(result_status)
+
+        if next_state:
+            logger.info(f"Job {job_id} transitioning based on worker status '{result_status}' to state '{next_state}'")
+
+            worker_data_content = result_data.get("data")
+            if worker_data_content and isinstance(worker_data_content, dict):
+                if "state_history" not in job_state:
+                    job_state["state_history"] = {}
+                job_state["state_history"].update(worker_data_content)
+
+            data_metadata = result_payload.get("data_metadata")
+            if data_metadata:
+                if "data_metadata" not in job_state:
+                    job_state["data_metadata"] = {}
+                job_state["data_metadata"].update(data_metadata)
+                logger.debug(f"Stored data metadata for job {job_id}: {list(data_metadata.keys())}")
+
+            job_state["current_state"] = next_state
+            job_state["status"] = JOB_STATUS_RUNNING
+            await self.storage.save_job_state(job_id, job_state)
+            await self.storage.enqueue_job(job_id)
+            return "result_accepted_success"
+        else:
+            logger.error(f"Job {job_id} failed. Worker returned unhandled status '{result_status}'.")
+            job_state["status"] = JOB_STATUS_FAILED
+            job_state["error_message"] = f"Worker returned unhandled status: {result_status}"
+            await self.storage.save_job_state(job_id, job_state)
+            return "result_accepted_failure"
+
+    async def _handle_task_failure(self, job_state: dict, task_id: str, result_data: dict) -> str:
+        error_details = result_data.get("error", {})
+        error_type = ERROR_CODE_TRANSIENT
+        error_message = "No error details provided."
+
+        if isinstance(error_details, dict):
+            error_type = error_details.get("code", ERROR_CODE_TRANSIENT)
+            error_message = error_details.get("message", "No error message provided.")
+        elif isinstance(error_details, str):
+            error_message = error_details
+
+        job_id = job_state["id"]
+        logger.warning(f"Task {task_id} for job {job_id} failed with error type '{error_type}'.")
+
+        if error_type in (ERROR_CODE_PERMANENT, ERROR_CODE_SECURITY, ERROR_CODE_DEPENDENCY):
+            job_state["status"] = JOB_STATUS_QUARANTINED
+            job_state["error_message"] = f"Task failed with permanent error ({error_type}): {error_message}"
+            await self.storage.save_job_state(job_id, job_state)
+            await self.storage.quarantine_job(job_id)
+        elif error_type == ERROR_CODE_INVALID_INPUT:
+            job_state["status"] = JOB_STATUS_FAILED
+            job_state["error_message"] = f"Task failed due to invalid input: {error_message}"
+            await self.storage.save_job_state(job_id, job_state)
+        elif error_type == ERROR_CODE_INTEGRITY_MISMATCH:
+            job_state["status"] = JOB_STATUS_FAILED
+            job_state["error_message"] = f"Task failed due to data integrity mismatch: {error_message}"
+            await self.storage.save_job_state(job_id, job_state)
+            logger.critical(f"Data integrity mismatch detected for job {job_id}: {error_message}")
+        else:
+            await self.engine.handle_task_failure(job_state, task_id, error_message)
+        return "result_accepted_failure"
+
+    async def issue_access_token(self, worker_id: str) -> TokenResponse:
+        """Generates and stores a temporary access token."""
+        raw_token = token_urlsafe(32)
+        token_hash = sha256(raw_token.encode()).hexdigest()
+        ttl = 3600
+
+        await self.storage.save_worker_access_token(worker_id, token_hash, ttl)
+        logger.info(f"Issued temporary access token for worker {worker_id}")
+
+        return TokenResponse(access_token=raw_token, expires_in=ttl, worker_id=worker_id)
+
+    async def update_worker_heartbeat(
+        self, worker_id: str, update_data: Optional[dict[str, Any]]
+    ) -> Optional[dict[str, Any]]:
+        """Updates worker TTL and status."""
+        ttl = self.config.WORKER_HEALTH_CHECK_INTERVAL_SECONDS * 2
+
+        if update_data:
+            updated_worker = await self.storage.update_worker_status(worker_id, update_data, ttl)
+            if updated_worker:
+                await self.history_storage.log_worker_event(
+                    {
+                        "worker_id": worker_id,
+                        "event_type": "status_update",
+                        "worker_info_snapshot": updated_worker,
+                    },
+                )
+            return updated_worker
+        else:
+            refreshed = await self.storage.refresh_worker_ttl(worker_id, ttl)
+            return {"status": "ttl_refreshed"} if refreshed else None
```
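Note the hash-at-rest design in `issue_access_token`: the raw token goes to the worker exactly once while only its sha256 digest is persisted, and `verify_worker_auth` (security.py above) hashes whatever the worker presents before calling `storage.verify_worker_access_token`. A self-contained toy version of that round trip; a plain dict stands in for the storage backends and the TTL is omitted:

```python
from hashlib import sha256
from secrets import token_urlsafe

# Simplified illustration of the STS flow above, not the package's classes:
# a leaked dump of `store` never reveals usable credentials, because only
# hashes are kept and lookup is by hash of the presented token.
store: dict[str, str] = {}  # token hash -> worker_id

def issue(worker_id: str) -> str:
    raw = token_urlsafe(32)
    store[sha256(raw.encode()).hexdigest()] = worker_id
    return raw  # handed to the worker once, never persisted

def verify(presented: str) -> str | None:
    return store.get(sha256(presented.encode()).hexdigest())

token = issue("worker-42")
assert verify(token) == "worker-42"
assert verify("forged-token") is None
```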
avtomatika/storage/base.py
CHANGED

```diff
@@ -292,6 +292,16 @@ class StorageBackend(ABC):
         """Retrieves an individual token for a specific worker."""
         raise NotImplementedError
 
+    @abstractmethod
+    async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
+        """Saves a temporary access token for a worker (STS)."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def verify_worker_access_token(self, token: str) -> str | None:
+        """Verifies a temporary access token and returns the associated worker_id if valid."""
+        raise NotImplementedError
+
     @abstractmethod
     async def get_worker_info(self, worker_id: str) -> dict[str, Any] | None:
         """Get complete information about a worker by its ID."""
```
avtomatika/storage/memory.py
CHANGED

```diff
@@ -12,12 +12,12 @@ class MemoryStorage(StorageBackend):
     Not persistent.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._jobs: dict[str, dict[str, Any]] = {}
         self._workers: dict[str, dict[str, Any]] = {}
         self._worker_ttls: dict[str, float] = {}
-        self._worker_task_queues: dict[str, PriorityQueue] = {}
-        self._job_queue = Queue()
+        self._worker_task_queues: dict[str, PriorityQueue[Any]] = {}
+        self._job_queue: Queue[str] = Queue()
         self._quarantine_queue: list[str] = []
         self._watched_jobs: dict[str, float] = {}
         self._client_configs: dict[str, dict[str, Any]] = {}
@@ -189,10 +189,11 @@ class MemoryStorage(StorageBackend):
         async with self._lock:
             self._watched_jobs.pop(job_id, None)
 
-    async def get_timed_out_jobs(self) -> list[str]:
+    async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
         async with self._lock:
             now = monotonic()
             timed_out_ids = [job_id for job_id, timeout_at in self._watched_jobs.items() if timeout_at <= now]
+            timed_out_ids = timed_out_ids[:limit]
             for job_id in timed_out_ids:
                 self._watched_jobs.pop(job_id, None)
             return timed_out_ids
@@ -331,6 +332,16 @@ class MemoryStorage(StorageBackend):
         async with self._lock:
             return self._worker_tokens.get(worker_id)
 
+    async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
+        async with self._lock:
+            self._generic_keys[f"sts:{token}"] = worker_id
+            self._generic_key_ttls[f"sts:{token}"] = monotonic() + ttl
+
+    async def verify_worker_access_token(self, token: str) -> str | None:
+        async with self._lock:
+            await self._clean_expired()
+            return self._generic_keys.get(f"sts:{token}")
+
     async def set_task_cancellation_flag(self, task_id: str) -> None:
         key = f"task_cancel:{task_id}"
         await self.increment_key_with_ttl(key, 3600)
```
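The new `limit` parameter on `get_timed_out_jobs` caps how many watched jobs a single call drains, matching the batching the Redis backend does below. A hedged sketch of the kind of polling loop that benefits; `storage` and `fail_job` are stand-ins, not the package's actual `Watcher`:

```python
from asyncio import sleep

# Hypothetical watcher-style loop: each pass handles at most `limit`
# expirations, so a sudden burst of timed-out jobs cannot monopolize the
# event loop in one iteration.
async def watch_loop(storage, fail_job, interval: float = 1.0) -> None:
    while True:
        timed_out = await storage.get_timed_out_jobs(limit=100)
        for job_id in timed_out:
            await fail_job(job_id)
        # A full batch suggests more are pending; poll again immediately.
        if len(timed_out) < 100:
            await sleep(interval)
```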
avtomatika/storage/redis.py
CHANGED

```diff
@@ -95,7 +95,7 @@ class RedisStorage(StorageBackend):
         self,
         job_id: str,
         update_data: dict[str, Any],
-    ) -> dict[
+    ) -> dict[str, Any]:
         """Atomically update the job state in Redis using a transaction."""
         key = self._get_key(job_id)
 
@@ -104,7 +104,7 @@ class RedisStorage(StorageBackend):
         try:
             await pipe.watch(key)
             current_state_raw = await pipe.get(key)
-            current_state = self._unpack(current_state_raw) if current_state_raw else {}
+            current_state: dict[str, Any] = self._unpack(current_state_raw) if current_state_raw else {}
             current_state.update(update_data)
 
             pipe.multi()
@@ -147,7 +147,7 @@ class RedisStorage(StorageBackend):
         key = f"orchestrator:worker:info:{worker_id}"
         tasks_key = f"orchestrator:worker:tasks:{worker_id}"
 
-        tasks = await self._redis.smembers(tasks_key)  # type: ignore
+        tasks = await self._redis.smembers(tasks_key)  # type: ignore[var-annotated]
 
         async with self._redis.pipeline(transaction=True) as pipe:
             pipe.delete(key)
@@ -156,7 +156,7 @@ class RedisStorage(StorageBackend):
             pipe.srem("orchestrator:index:workers:idle", worker_id)
 
             for task in tasks:
-                task_str = task.decode("utf-8") if isinstance(task, bytes) else task
+                task_str = task.decode("utf-8") if isinstance(task, bytes) else str(task)
                 pipe.srem(f"orchestrator:index:workers:task:{task_str}", worker_id)
 
             await pipe.execute()
@@ -204,8 +204,8 @@ class RedisStorage(StorageBackend):
         """Finds idle workers that support the given task using set intersection."""
         task_index = f"orchestrator:index:workers:task:{task_type}"
         idle_index = "orchestrator:index:workers:idle"
-        worker_ids = await self._redis.sinter(task_index, idle_index)  # type: ignore
-        return [wid.decode("utf-8") if isinstance(wid, bytes) else wid for wid in worker_ids]
+        worker_ids = await self._redis.sinter(task_index, idle_index)  # type: ignore[var-annotated]
+        return [wid.decode("utf-8") if isinstance(wid, bytes) else str(wid) for wid in worker_ids]
 
     async def enqueue_task_for_worker(self, worker_id: str, task_payload: dict[str, Any], priority: float) -> None:
         key = f"orchestrator:task_queue:{worker_id}"
@@ -274,13 +274,14 @@ class RedisStorage(StorageBackend):
             existence = await pipe.execute()
         dead_ids = [worker_ids[i] for i, exists in enumerate(existence) if not exists]
         for wid in dead_ids:
-            tasks = await self._redis.smembers(f"orchestrator:worker:tasks:{wid}")  # type: ignore
+            tasks = await self._redis.smembers(f"orchestrator:worker:tasks:{wid}")  # type: ignore[var-annotated]
             async with self._redis.pipeline(transaction=True) as p:
                 p.delete(f"orchestrator:worker:tasks:{wid}")
                 p.srem("orchestrator:index:workers:all", wid)
                 p.srem("orchestrator:index:workers:idle", wid)
                 for t in tasks:
-
+                    t_str = t.decode() if isinstance(t, bytes) else str(t)
+                    p.srem(f"orchestrator:index:workers:task:{t_str}", wid)
                 await p.execute()
 
     async def add_job_to_watch(self, job_id: str, timeout_at: float) -> None:
@@ -291,10 +292,33 @@ class RedisStorage(StorageBackend):
 
     async def get_timed_out_jobs(self, limit: int = 100) -> list[str]:
         now = get_running_loop().time()
-
+        # Lua script to atomically fetch and remove timed out jobs
+        LUA_POP_TIMEOUTS = """
+        local now = ARGV[1]
+        local limit = ARGV[2]
+        local ids = redis.call('ZRANGEBYSCORE', KEYS[1], 0, now, 'LIMIT', 0, limit)
+        if #ids > 0 then
+            redis.call('ZREM', KEYS[1], unpack(ids))
+        end
+        return ids
+        """
+        try:
+            sha = await self._redis.script_load(LUA_POP_TIMEOUTS)
+            ids = await self._redis.evalsha(sha, 1, "orchestrator:watched_jobs", now, limit)
+        except NoScriptError:
+            ids = await self._redis.eval(LUA_POP_TIMEOUTS, 1, "orchestrator:watched_jobs", now, limit)
+        except ResponseError as e:
+            # Fallback for Redis versions that don't support script_load/evalsha or other errors
+            if "unknown command" in str(e).lower():
+                logger.warning("Redis does not support LUA scripts. Falling back to non-atomic get_timed_out_jobs.")
+                ids = await self._redis.zrangebyscore("orchestrator:watched_jobs", 0, now, start=0, num=limit)
+                if ids:
+                    await self._redis.zrem("orchestrator:watched_jobs", *ids)  # type: ignore
+            else:
+                raise e
+
         if ids:
-
-            return [i.decode("utf-8") for i in ids]
+            return [i.decode("utf-8") if isinstance(i, bytes) else i for i in ids]
         return []
 
     async def enqueue_job(self, job_id: str) -> None:
@@ -411,6 +435,13 @@ class RedisStorage(StorageBackend):
         token = await self._redis.get(f"orchestrator:worker:token:{worker_id}")
         return token.decode("utf-8") if token else None
 
+    async def save_worker_access_token(self, worker_id: str, token: str, ttl: int) -> None:
+        await self._redis.set(f"orchestrator:sts:token:{token}", worker_id, ex=ttl)
+
+    async def verify_worker_access_token(self, token: str) -> str | None:
+        worker_id = await self._redis.get(f"orchestrator:sts:token:{token}")
+        return worker_id.decode("utf-8") if worker_id else None
+
     async def acquire_lock(self, key: str, holder_id: str, ttl: int) -> bool:
         return bool(await self._redis.set(f"orchestrator:lock:{key}", holder_id, nx=True, ex=ttl))
 
```

(Several removed lines in the `@@ -274` and `@@ -291` hunks are blank in the source diff rendering and are left as-is.)
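The Lua script exists because `ZRANGEBYSCORE` followed by a separate `ZREM` is a read-then-delete race: two orchestrator processes could both read the same ids before either removes them, and a timed-out job would be failed twice. A standalone sketch of the same pattern using redis-py's `register_script`, which caches the script SHA and transparently falls back to `EVAL` on `NOSCRIPT`, much like the `evalsha`/`NoScriptError` handling above; the key name is illustrative:

```python
from redis.asyncio import Redis

# Atomic "pop expired members" from a sorted set: range query and removal
# run as one script, so concurrent callers cannot pop the same ids.
POP_EXPIRED = """
local ids = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1], 'LIMIT', 0, ARGV[2])
if #ids > 0 then
    redis.call('ZREM', KEYS[1], unpack(ids))
end
return ids
"""

async def pop_expired(r: Redis, key: str, now: float, limit: int = 100) -> list[str]:
    script = r.register_script(POP_EXPIRED)  # handles EVALSHA caching/fallback
    ids = await script(keys=[key], args=[now, limit])
    return [i.decode() if isinstance(i, bytes) else i for i in ids]
```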
avtomatika/telemetry.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 from logging import getLogger
 from os import getenv
+from typing import Any
 
 logger = getLogger(__name__)
 
@@ -17,28 +18,28 @@ except ImportError:
     TELEMETRY_ENABLED = False
 
     class DummySpan:
-        def __enter__(self):
+        def __enter__(self) -> "DummySpan":
             return self
 
-        def __exit__(self, *args):
+        def __exit__(self, *args: Any) -> None:
             pass
 
-        def set_attribute(self, key, value):
+        def set_attribute(self, key: str, value: Any) -> None:
             pass
 
     class DummyTracer:
         @staticmethod
-        def start_as_current_span(name, context=None):
+        def start_as_current_span(name: str, context: Any = None) -> DummySpan:
             return DummySpan()
 
     class NoOpTrace:
-        def get_tracer(self, name):
+        def get_tracer(self, name: str) -> DummyTracer:
             return DummyTracer()
 
-    trace = NoOpTrace()
+    trace: Any = NoOpTrace()  # type: ignore[no-redef]
 
 
-def setup_telemetry(service_name: str = "avtomatika"):
+def setup_telemetry(service_name: str = "avtomatika") -> Any:
     """Configures OpenTelemetry for the application if installed."""
     if not TELEMETRY_ENABLED:
         logger.info("opentelemetry-sdk not found. Telemetry is disabled.")
```
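The point of the dummy classes is that call sites never branch on whether OpenTelemetry is installed: they import `trace` and use the same span API either way, because the `except ImportError` branch binds `trace` to a no-op stand-in with a matching shape. A small usage sketch under that assumption (the import path mirrors the module shown above):

```python
from avtomatika.telemetry import trace  # assumed import path

tracer = trace.get_tracer(__name__)

def handle_job(job_id: str) -> None:
    # With opentelemetry-sdk installed this records a real span; without it,
    # DummySpan makes the context manager a harmless no-op.
    with tracer.start_as_current_span("handle_job") as span:
        span.set_attribute("job.id", job_id)
```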
avtomatika/utils/webhook_sender.py
CHANGED

```diff
@@ -1,8 +1,8 @@
-from asyncio import CancelledError, Queue, QueueFull, create_task, sleep
+from asyncio import CancelledError, Queue, QueueFull, Task, create_task, sleep
 from contextlib import suppress
 from dataclasses import asdict, dataclass
 from logging import getLogger
-from typing import Any
+from typing import Any, Optional
 
 from aiohttp import ClientSession, ClientTimeout
 
@@ -24,7 +24,7 @@ class WebhookSender:
         self.timeout = ClientTimeout(total=10)
         self.max_retries = 3
         self._queue: Queue[tuple[str, WebhookPayload]] = Queue(maxsize=1000)
-        self._worker_task = None
+        self._worker_task: Optional[Task[None]] = None
 
     def start(self) -> None:
         if not self._worker_task:
```
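The context lines reveal WebhookSender's delivery model: a bounded `Queue` drained by a single background `Task`, created lazily in `start()`. A stripped-down sketch of that pattern; the HTTP, retry, and shutdown logic of the real class are omitted:

```python
from asyncio import Queue, QueueFull, Task, create_task

# Toy version of the bounded fire-and-forget pattern above, not the
# package's WebhookSender.
class TinySender:
    def __init__(self) -> None:
        self._queue: Queue[str] = Queue(maxsize=1000)
        self._worker_task: Task[None] | None = None

    def start(self) -> None:
        if not self._worker_task:
            self._worker_task = create_task(self._run())

    def enqueue(self, event: str) -> None:
        try:
            self._queue.put_nowait(event)
        except QueueFull:
            pass  # shed load rather than block the caller's event loop

    async def _run(self) -> None:
        while True:
            event = await self._queue.get()
            print("would POST:", event)  # real class does HTTP with retries
```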
avtomatika/watcher.py
CHANGED

```diff
@@ -3,6 +3,8 @@ from logging import getLogger
 from typing import TYPE_CHECKING
 from uuid import uuid4
 
+from .constants import JOB_STATUS_FAILED, JOB_STATUS_WAITING_FOR_WORKER
+
 if TYPE_CHECKING:
     from .engine import OrchestratorEngine
 
@@ -38,8 +40,8 @@ class Watcher:
         try:
             # Get the latest version to avoid overwriting
             job_state = await self.storage.get_job_state(job_id)
-            if job_state and job_state["status"] ==
-                job_state["status"] =
+            if job_state and job_state["status"] == JOB_STATUS_WAITING_FOR_WORKER:
+                job_state["status"] = JOB_STATUS_FAILED
                 job_state["error_message"] = "Worker task timed out."
                 await self.storage.save_job_state(job_id, job_state)
```

(The two removed lines are cut short in the source diff rendering and are left as shown.)