avtomatika-1.0b7-py3-none-any.whl → avtomatika-1.0b9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -49,11 +49,12 @@ class SQLiteHistoryStorage(HistoryStorageBase):
     """
 
     def __init__(self, db_path: str, tz_name: str = "UTC"):
+        super().__init__()
         self._db_path = db_path
         self._conn: Connection | None = None
         self.tz = ZoneInfo(tz_name)
 
-    async def initialize(self):
+    async def initialize(self) -> None:
         """Initializes the database connection and creates tables if they don't exist."""
         try:
             self._conn = await connect(self._db_path)
@@ -68,8 +69,9 @@ class SQLiteHistoryStorage(HistoryStorageBase):
             logger.error(f"Failed to initialize SQLite history storage: {e}")
             raise
 
-    async def close(self):
-        """Closes the database connection."""
+    async def close(self) -> None:
+        """Closes the database connection and background worker."""
+        await super().close()
         if self._conn:
             await self._conn.close()
             logger.info("SQLite history storage connection closed.")
@@ -91,7 +93,7 @@ class SQLiteHistoryStorage(HistoryStorageBase):
 
         return item
 
-    async def log_job_event(self, event_data: dict[str, Any]):
+    async def _persist_job_event(self, event_data: dict[str, Any]) -> None:
         """Logs a job lifecycle event to the job_history table."""
         if not self._conn:
             raise RuntimeError("History storage is not initialized.")
@@ -128,7 +130,7 @@ class SQLiteHistoryStorage(HistoryStorageBase):
         except Error as e:
             logger.error(f"Failed to log job event: {e}")
 
-    async def log_worker_event(self, event_data: dict[str, Any]):
+    async def _persist_worker_event(self, event_data: dict[str, Any]) -> None:
         """Logs a worker lifecycle event to the worker_history table."""
         if not self._conn:
             raise RuntimeError("History storage is not initialized.")
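Note on this hunk: the public `log_job_event`/`log_worker_event` methods become private `_persist_*` hooks, and the SQLite backend now calls `super().__init__()` and `await super().close()`. This suggests that `HistoryStorageBase` now owns the public logging API plus a background worker and delegates the actual writes to the `_persist_*` methods. The base class is not included in this diff, so the following is only a hypothetical sketch of that split (only the job-event path is shown), not the package's actual implementation:

```python
# Hypothetical sketch: HistoryStorageBase is not shown in this diff.
# Assumes a queue-backed base class whose public log_job_event() enqueues
# events while a background task calls the subclass's _persist_job_event().
from abc import ABC, abstractmethod
from asyncio import Queue, Task, create_task
from typing import Any


class HistoryStorageBaseSketch(ABC):
    def __init__(self) -> None:
        self._queue: Queue[dict[str, Any]] = Queue()
        self._worker: Task[None] | None = None

    async def log_job_event(self, event_data: dict[str, Any]) -> None:
        # Callers never block on the database; persistence happens in the background.
        if self._worker is None:
            self._worker = create_task(self._drain())
        await self._queue.put(event_data)

    async def _drain(self) -> None:
        while True:
            event = await self._queue.get()
            await self._persist_job_event(event)
            self._queue.task_done()

    @abstractmethod
    async def _persist_job_event(self, event_data: dict[str, Any]) -> None: ...

    async def close(self) -> None:
        # Flush whatever is queued, then stop the worker. Subclasses call
        # `await super().close()` before closing their own connections.
        if self._worker is not None:
            await self._queue.join()
            self._worker.cancel()
```

Under that assumption, the new `await super().close()` in the SQLite backend would flush pending events before the database connection is closed.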
avtomatika/metrics.py CHANGED
@@ -12,7 +12,7 @@ task_queue_length: Gauge
 active_workers: Gauge
 
 
-def init_metrics():
+def init_metrics() -> None:
     """
     Initializes Prometheus metrics.
     Uses a registry check for idempotency, which is important for tests.
avtomatika/reputation.py CHANGED
@@ -52,48 +52,54 @@ class ReputationCalculator:
     async def calculate_all_reputations(self):
         """Calculates and updates the reputation for all active workers."""
         logger.info("Starting reputation calculation for all workers...")
-        workers = await self.storage.get_available_workers()
-        if not workers:
+
+        # Get only IDs of active workers to avoid O(N) scan of all data
+        worker_ids = await self.storage.get_active_worker_ids()
+
+        if not worker_ids:
             logger.info("No active workers found for reputation calculation.")
             return
 
-        for worker in workers:
-            worker_id = worker.get("worker_id")
-            if not worker_id:
-                continue
-
-            history = await self.history_storage.get_worker_history(
-                worker_id,
-                since_days=REPUTATION_HISTORY_DAYS,
-            )
-
-            # Count only task completion events
-            task_finished_events = [event for event in history if event.get("event_type") == "task_finished"]
-
-            if not task_finished_events:
-                # If there is no history, the reputation does not change (remains 1.0 by default)
-                continue
-
-            successful_tasks = 0
-            for event in task_finished_events:
-                # Extract the result from the snapshot
-                snapshot = event.get("context_snapshot", {})
-                result = snapshot.get("result", {})
-                if result.get("status") == "success":
-                    successful_tasks += 1
-
-            total_tasks = len(task_finished_events)
-            new_reputation = successful_tasks / total_tasks if total_tasks > 0 else 1.0
-
-            # Round for cleanliness
-            new_reputation = round(new_reputation, 4)
-
-            logger.info(
-                f"Updating reputation for worker {worker_id}: {worker.get('reputation')} -> {new_reputation}",
-            )
-            await self.storage.update_worker_data(
-                worker_id,
-                {"reputation": new_reputation},
-            )
+        logger.info(f"Recalculating reputation for {len(worker_ids)} workers.")
+
+        for worker_id in worker_ids:
+            if not self._running:
+                break
+
+            try:
+                history = await self.history_storage.get_worker_history(
+                    worker_id,
+                    since_days=REPUTATION_HISTORY_DAYS,
+                )
+
+                # Count only task completion events
+                task_finished_events = [event for event in history if event.get("event_type") == "task_finished"]
+
+                if not task_finished_events:
+                    # If there is no history, skip to next worker
+                    continue
+
+                successful_tasks = 0
+                for event in task_finished_events:
+                    # Extract the result from the snapshot
+                    snapshot = event.get("context_snapshot", {})
+                    result = snapshot.get("result", {})
+                    if result.get("status") == "success":
+                        successful_tasks += 1
+
+                total_tasks = len(task_finished_events)
+                new_reputation = successful_tasks / total_tasks if total_tasks > 0 else 1.0
+                new_reputation = round(new_reputation, 4)
+
+                await self.storage.update_worker_data(
+                    worker_id,
+                    {"reputation": new_reputation},
+                )
+
+                # Throttling: Small sleep to prevent DB spikes
+                await sleep(0.1)
+
+            except Exception as e:
+                logger.error(f"Failed to calculate reputation for worker {worker_id}: {e}")
 
         logger.info("Reputation calculation finished.")
avtomatika/s3.py ADDED
@@ -0,0 +1,379 @@
+from asyncio import Semaphore, gather, to_thread
+from logging import getLogger
+from os import sep, walk
+from pathlib import Path
+from shutil import rmtree
+from typing import Any
+
+from aiofiles import open as aiopen
+from obstore import delete_async, get_async, put_async
+from obstore import list as obstore_list
+from obstore.store import S3Store
+from orjson import dumps, loads
+from rxon.blob import calculate_config_hash, parse_uri
+from rxon.exceptions import IntegrityError
+
+from .config import Config
+from .history.base import HistoryStorageBase
+
+logger = getLogger(__name__)
+
+try:
+    HAS_S3_LIBS = True
+except ImportError:
+    HAS_S3_LIBS = False
+    S3Store = Any
+
+
+class TaskFiles:
+    """
+    Manages files for a specific job, ensuring full compatibility with avtomatika-worker.
+    Supports recursive directory download/upload and non-blocking I/O.
+    """
+
+    def __init__(
+        self,
+        store: "S3Store",
+        bucket: str,
+        job_id: str,
+        base_local_dir: str | Path,
+        semaphore: Semaphore,
+        history: HistoryStorageBase | None = None,
+    ):
+        self._store = store
+        self._bucket = bucket
+        self._job_id = job_id
+        self._history = history
+        self._s3_prefix = f"jobs/{job_id}/"
+        self.local_dir = Path(base_local_dir) / job_id
+        self._semaphore = semaphore
+
+    def _ensure_local_dir(self) -> None:
+        if not self.local_dir.exists():
+            self.local_dir.mkdir(parents=True, exist_ok=True)
+
+    def path(self, filename: str) -> Path:
+        """Returns local path for a filename, ensuring the directory exists."""
+        self._ensure_local_dir()
+        clean_name = filename.split("/")[-1] if "://" in filename else filename.lstrip("/")
+        return self.local_dir / clean_name
+
+    async def _download_single_file(
+        self,
+        key: str,
+        local_path: Path,
+        expected_size: int | None = None,
+        expected_hash: str | None = None,
+    ) -> dict[str, Any]:
+        """Downloads a single file safely using semaphore and streaming.
+        Returns metadata (size, etag).
+        """
+        if not local_path.parent.exists():
+            await to_thread(local_path.parent.mkdir, parents=True, exist_ok=True)
+
+        async with self._semaphore:
+            response = await get_async(self._store, key)
+            meta = response.meta
+            file_size = meta.size
+            etag = meta.e_tag.strip('"') if meta.e_tag else None
+
+            if expected_size is not None and file_size != expected_size:
+                raise IntegrityError(f"File size mismatch for {key}: expected {expected_size}, got {file_size}")
+
+            if expected_hash is not None and etag and expected_hash != etag:
+                raise IntegrityError(f"Integrity mismatch for {key}: expected ETag {expected_hash}, got {etag}")
+
+            stream = response.stream()
+            async with aiopen(local_path, "wb") as f:
+                async for chunk in stream:
+                    await f.write(chunk)
+
+            return {"size": file_size, "etag": etag}
+
+    async def download(
+        self,
+        name_or_uri: str,
+        local_name: str | None = None,
+        verify_meta: dict[str, Any] | None = None,
+    ) -> Path:
+        """
+        Downloads a file or directory (recursively).
+        If URI ends with '/', it treats it as a directory.
+        """
+        bucket, key, is_dir = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
+        verify_meta = verify_meta or {}
+
+        if local_name:
+            target_path = self.path(local_name)
+        else:
+            suffix = key.replace(self._s3_prefix, "", 1) if key.startswith(self._s3_prefix) else key.split("/")[-1]
+            target_path = self.local_dir / suffix
+
+        if is_dir:
+            logger.info(f"Recursive download: s3://{bucket}/{key} -> {target_path}")
+            entries = await to_thread(lambda: list(obstore_list(self._store, prefix=key)))
+
+            tasks = []
+            for entry in entries:
+                s3_key = entry["path"]
+                rel_path = s3_key[len(key) :]
+                if not rel_path:
+                    continue
+
+                local_file_path = target_path / rel_path
+                tasks.append(self._download_single_file(s3_key, local_file_path))
+
+            if tasks:
+                results = await gather(*tasks)
+                total_size = sum(r["size"] for r in results)
+                await self._log_event(
+                    "download_dir",
+                    f"s3://{bucket}/{key}",
+                    str(target_path),
+                    metadata={"total_size": total_size, "file_count": len(results)},
+                )
+            else:
+                await self._log_event(
+                    "download_dir",
+                    f"s3://{bucket}/{key}",
+                    str(target_path),
+                    metadata={"total_size": 0, "file_count": 0},
+                )
+            return target_path
+        else:
+            logger.debug(f"Downloading s3://{bucket}/{key} -> {target_path}")
+            meta = await self._download_single_file(
+                key,
+                target_path,
+                expected_size=verify_meta.get("size"),
+                expected_hash=verify_meta.get("hash"),
+            )
+            await self._log_event("download", f"s3://{bucket}/{key}", str(target_path), metadata=meta)
+            return target_path
+
+    async def _upload_single_file(self, local_path: Path, s3_key: str) -> dict[str, Any]:
+        """Uploads a single file safely using semaphore. Returns S3 metadata."""
+        async with self._semaphore:
+            file_size = local_path.stat().st_size
+            async with aiopen(local_path, "rb") as f:
+                content = await f.read()
+            result = await put_async(self._store, s3_key, content)
+            etag = result.e_tag.strip('"') if result.e_tag else None
+            return {"size": file_size, "etag": etag}
+
+    async def upload(self, local_name: str, remote_name: str | None = None) -> str:
+        """
+        Uploads a file or directory recursively.
+        If local_name points to a directory, it uploads all contents.
+        """
+        local_path = self.path(local_name)
+
+        if local_path.is_dir():
+            base_remote = (remote_name or local_name).lstrip("/")
+            if not base_remote.endswith("/"):
+                base_remote += "/"
+
+            target_prefix = f"{self._s3_prefix}{base_remote}"
+            logger.info(f"Recursive upload: {local_path} -> s3://{self._bucket}/{target_prefix}")
+
+            def collect_files():
+                files_to_upload = []
+                for root, _, files in walk(local_path):
+                    for file in files:
+                        abs_path = Path(root) / file
+                        rel_path = abs_path.relative_to(local_path)
+                        s3_key = f"{target_prefix}{str(rel_path).replace(sep, '/')}"
+                        files_to_upload.append((abs_path, s3_key))
+                return files_to_upload
+
+            files_map = await to_thread(collect_files)
+
+            tasks = [self._upload_single_file(lp, k) for lp, k in files_map]
+            if tasks:
+                results = await gather(*tasks)
+                total_size = sum(r["size"] for r in results)
+                metadata = {"total_size": total_size, "file_count": len(results)}
+            else:
+                metadata = {"total_size": 0, "file_count": 0}
+
+            uri = f"s3://{self._bucket}/{target_prefix}"
+            await self._log_event("upload_dir", uri, str(local_path), metadata=metadata)
+            return uri
+
+        elif local_path.exists():
+            target_key = f"{self._s3_prefix}{(remote_name or local_name).lstrip('/')}"
+            logger.debug(f"Uploading {local_path} -> s3://{self._bucket}/{target_key}")
+
+            meta = await self._upload_single_file(local_path, target_key)
+
+            uri = f"s3://{self._bucket}/{target_key}"
+            await self._log_event("upload", uri, str(local_path), metadata=meta)
+            return uri
+        else:
+            raise FileNotFoundError(f"Local file/dir not found: {local_path}")
+
+    async def read_text(self, name_or_uri: str) -> str:
+        bucket, key, _ = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
+        filename = key.split("/")[-1]
+        local_path = self.path(filename)
+
+        if not local_path.exists():
+            await self.download(name_or_uri)
+
+        async with aiopen(local_path, "r", encoding="utf-8") as f:
+            return await f.read()
+
+    async def read_json(self, name_or_uri: str) -> Any:
+        bucket, key, _ = parse_uri(name_or_uri, self._bucket, self._s3_prefix)
+        filename = key.split("/")[-1]
+        local_path = self.path(filename)
+
+        if not local_path.exists():
+            await self.download(name_or_uri)
+
+        async with aiopen(local_path, "rb") as f:
+            content = await f.read()
+            return loads(content)
+
+    async def write_json(self, filename: str, data: Any, upload: bool = True) -> str:
+        """Writes JSON locally (binary mode) and optionally uploads to S3."""
+        local_path = self.path(filename)
+        json_bytes = dumps(data)
+
+        async with aiopen(local_path, "wb") as f:
+            await f.write(json_bytes)
+
+        if upload:
+            return await self.upload(filename)
+        return f"file://{local_path}"
+
+    async def write_text(self, filename: str, text: str, upload: bool = True) -> Path:
+        local_path = self.path(filename)
+        async with aiopen(local_path, "w", encoding="utf-8") as f:
+            await f.write(text)
+
+        if upload:
+            await self.upload(filename)
+
+        return local_path
+
+    async def cleanup(self) -> None:
+        """Full cleanup of S3 prefix and local job directory."""
+        logger.info(f"Cleanup for job {self._job_id}...")
+        try:
+            entries = await to_thread(lambda: list(obstore_list(self._store, prefix=self._s3_prefix)))
+            paths_to_delete = [entry["path"] for entry in entries]
+            if paths_to_delete:
+                await delete_async(self._store, paths_to_delete)
+        except Exception as e:
+            logger.error(f"S3 cleanup error: {e}")
+
+        if self.local_dir.exists():
+            await to_thread(rmtree, self.local_dir)
+
+    async def _log_event(
+        self,
+        operation: str,
+        file_uri: str,
+        local_path: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        if not self._history:
+            return
+
+        try:
+            context_snapshot = {
+                "operation": operation,
+                "s3_uri": file_uri,
+                "local_path": str(local_path),
+            }
+            if metadata:
+                context_snapshot.update(metadata)
+
+            await self._history.log_job_event(
+                {
+                    "job_id": self._job_id,
+                    "event_type": "s3_operation",
+                    "state": "running",
+                    "context_snapshot": context_snapshot,
+                }
+            )
+        except Exception as e:
+            logger.warning(f"Failed to log S3 event: {e}")
+
+
+class S3Service:
+    """
+    Central service for S3 operations.
+    Initializes the Store and provides TaskFiles instances.
+    """
+
+    def __init__(self, config: Config, history: HistoryStorageBase | None = None):
+        self.config = config
+        self._history = history
+        self._store: S3Store | None = None
+        self._semaphore: Semaphore | None = None
+
+        self._config_present = bool(config.S3_ENDPOINT_URL and config.S3_ACCESS_KEY and config.S3_SECRET_KEY)
+
+        if self._config_present:
+            if HAS_S3_LIBS:
+                self._enabled = True
+                self._initialize_store()
+            else:
+                logger.error(
+                    "S3 configuration found, but 'avtomatika[s3]' extra dependencies are not installed. "
+                    "S3 support will be disabled. Install with: pip install 'avtomatika[s3]'"
+                )
+                self._enabled = False
+        else:
+            self._enabled = False
+            if any([config.S3_ENDPOINT_URL, config.S3_ACCESS_KEY, config.S3_SECRET_KEY]):
+                logger.warning("Partial S3 configuration found. S3 support disabled.")
+
+    def _initialize_store(self) -> None:
+        try:
+            self._store = S3Store(
+                bucket=self.config.S3_DEFAULT_BUCKET,
+                access_key_id=self.config.S3_ACCESS_KEY,
+                secret_access_key=self.config.S3_SECRET_KEY,
+                region=self.config.S3_REGION,
+                endpoint=self.config.S3_ENDPOINT_URL,
+                allow_http="http://" in self.config.S3_ENDPOINT_URL,
+                force_path_style=True,
+            )
+            self._semaphore = Semaphore(self.config.S3_MAX_CONCURRENCY)
+            logger.info(
+                f"S3Service initialized (Endpoint: {self.config.S3_ENDPOINT_URL}, "
+                f"Bucket: {self.config.S3_DEFAULT_BUCKET}, "
+                f"Max Concurrency: {self.config.S3_MAX_CONCURRENCY})"
+            )
+        except Exception as e:
+            logger.error(f"Failed to initialize S3 Store: {e}")
+            self._enabled = False
+
+    def get_config_hash(self) -> str | None:
+        """Returns a hash of the current S3 configuration for consistency checks."""
+        if not self._enabled:
+            return None
+        return calculate_config_hash(
+            self.config.S3_ENDPOINT_URL,
+            self.config.S3_ACCESS_KEY,
+            self.config.S3_DEFAULT_BUCKET,
+        )
+
+    def get_task_files(self, job_id: str) -> TaskFiles | None:
+        if not self._enabled or not self._store or not self._semaphore:
+            return None
+
+        return TaskFiles(
+            self._store,
+            self.config.S3_DEFAULT_BUCKET,
+            job_id,
+            self.config.TASK_FILES_DIR,
+            self._semaphore,
+            self._history,
+        )
+
+    async def close(self) -> None:
+        pass
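Note on this new module: `S3Service` only enables itself when `S3_ENDPOINT_URL`, `S3_ACCESS_KEY`, and `S3_SECRET_KEY` are all configured and the optional `avtomatika[s3]` dependencies are importable; otherwise `get_task_files()` returns `None`, and each `TaskFiles` instance works under the `jobs/<job_id>/` prefix of the default bucket. The sketch below is a hedged usage example inferred only from the signatures above; the job ID is a placeholder and how the scheduler actually wires this service up is not shown in this diff.

```python
# Usage sketch inferred from the class definitions in avtomatika/s3.py above.
# The job id is a placeholder; Config values come from your own deployment.
from avtomatika.config import Config
from avtomatika.s3 import S3Service


async def stage_job_files(config: Config) -> None:
    s3 = S3Service(config)                # disabled if S3 config or extras are missing
    files = s3.get_task_files("job-123")  # per-job helper rooted at s3://<bucket>/jobs/job-123/
    if files is None:
        return                            # S3 support is off; nothing to stage

    # Download a single object; verify_meta={"size": ..., "hash": ...} can be
    # passed to enforce size/ETag checks (raises IntegrityError on mismatch).
    await files.download("input.json")

    # A trailing slash means "treat as a directory" and download recursively.
    await files.download("s3://my-bucket/datasets/run-1/")

    payload = await files.read_json("input.json")        # cached locally after download
    result_uri = await files.write_json("result.json", {"ok": True, "echo": payload})

    await files.cleanup()                 # removes the S3 prefix and the local job dir
    print(result_uri)
```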
avtomatika/security.py CHANGED
@@ -10,6 +10,62 @@ from .storage.base import StorageBackend
 Handler = Callable[[web.Request], Awaitable[web.Response]]
 
 
+async def verify_worker_auth(
+    storage: StorageBackend,
+    config: Config,
+    token: str | None,
+    cert_identity: str | None,
+    worker_id_hint: str | None,
+) -> str:
+    """
+    Verifies worker authentication using token or mTLS.
+    Returns authenticated worker_id.
+    Raises ValueError (400), PermissionError (401/403) on failure.
+    """
+    # mTLS Check
+    if cert_identity:
+        if worker_id_hint and cert_identity != worker_id_hint:
+            raise PermissionError(
+                f"Unauthorized: Certificate CN '{cert_identity}' does not match worker_id '{worker_id_hint}'"
+            )
+        return cert_identity
+
+    # Token Check
+    if not token:
+        raise PermissionError(f"Missing {AUTH_HEADER_WORKER} header or client certificate")
+
+    hashed_provided_token = sha256(token.encode()).hexdigest()
+
+    # STS Access Token
+    token_worker_id = await storage.verify_worker_access_token(hashed_provided_token)
+    if token_worker_id:
+        if worker_id_hint and token_worker_id != worker_id_hint:
+            raise PermissionError(
+                f"Unauthorized: Access Token belongs to '{token_worker_id}', but request is for '{worker_id_hint}'"
+            )
+        return token_worker_id
+
+    # Individual/Global Token
+    if not worker_id_hint:
+        if config.GLOBAL_WORKER_TOKEN and token == config.GLOBAL_WORKER_TOKEN:
+            return "unknown_authenticated_by_global_token"
+
+        raise PermissionError("Unauthorized: Invalid token or missing worker_id hint")
+
+    # Individual Token for specific worker
+    expected_token_hash = await storage.get_worker_token(worker_id_hint)
+    if expected_token_hash:
+        if hashed_provided_token == expected_token_hash:
+            return worker_id_hint
+        raise PermissionError("Unauthorized: Invalid individual worker token")
+
+    # Global Token Fallback
+    if config.GLOBAL_WORKER_TOKEN and token == config.GLOBAL_WORKER_TOKEN:
+        return worker_id_hint
+
+    raise PermissionError("Unauthorized: No valid token found")
+
+
 def client_auth_middleware_factory(
     storage: StorageBackend,
 ) -> Any:
@@ -38,77 +94,3 @@ def client_auth_middleware_factory(
         return await handler(request)
 
     return middleware
-
-
-def worker_auth_middleware_factory(
-    storage: StorageBackend,
-    config: Config,
-) -> Any:
-    """
-    Middleware factory for worker authentication.
-    It supports both individual tokens and a global fallback token for backward compatibility.
-    It also attaches the authenticated worker_id to the request.
-    """
-
-    @web.middleware
-    async def middleware(request: web.Request, handler: Handler) -> web.Response:
-        provided_token = request.headers.get(AUTH_HEADER_WORKER)
-        if not provided_token:
-            return web.json_response(
-                {"error": f"Missing {AUTH_HEADER_WORKER} header"},
-                status=401,
-            )
-
-        worker_id = request.match_info.get("worker_id")
-        data = None
-
-        # For specific endpoints, worker_id is in the body.
-        # We need to read the body here, which can be tricky as it's a stream.
-        # We clone the request to allow the handler to read the body again.
-        if not worker_id and (request.path.endswith("/register") or request.path.endswith("/tasks/result")):
-            try:
-                cloned_request = request.clone()
-                data = await cloned_request.json()
-                worker_id = data.get("worker_id")
-                # Attach the parsed data to the request so the handler doesn't need to re-parse
-                if request.path.endswith("/register"):
-                    request["worker_registration_data"] = data
-            except Exception:
-                return web.json_response({"error": "Invalid JSON body"}, status=400)
-
-        # If no worker_id could be determined from path or body, we can only validate against the global token.
-        if not worker_id:
-            if provided_token == config.GLOBAL_WORKER_TOKEN:
-                # We don't know the worker_id, so we can't attach it.
-                return await handler(request)
-            else:
-                return web.json_response(
-                    {"error": "Unauthorized: Invalid token or missing worker_id"},
-                    status=401,
-                )
-
-        # --- Individual Token Check ---
-        expected_token_hash = await storage.get_worker_token(worker_id)
-        if expected_token_hash:
-            hashed_provided_token = sha256(provided_token.encode()).hexdigest()
-            if hashed_provided_token == expected_token_hash:
-                request["worker_id"] = worker_id  # Attach authenticated worker_id
-                return await handler(request)
-            else:
-                # If an individual token exists, we do not fall back to the global token.
-                return web.json_response(
-                    {"error": "Unauthorized: Invalid individual worker token"},
-                    status=401,
-                )
-
-        # --- Global Token Fallback ---
-        if config.GLOBAL_WORKER_TOKEN and provided_token == config.GLOBAL_WORKER_TOKEN:
-            request["worker_id"] = worker_id  # Attach authenticated worker_id
-            return await handler(request)
-
-        return web.json_response(
-            {"error": "Unauthorized: No valid token found"},
-            status=401,
-        )
-
-    return middleware
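Note on this hunk: with `worker_auth_middleware_factory` removed, worker authentication is centralized in `verify_worker_auth`, which raises `PermissionError` (and, per its docstring, `ValueError`) instead of returning HTTP responses, so callers must translate those exceptions themselves. The handlers that call it are not part of this diff; the following is only a hypothetical sketch of how an aiohttp endpoint might do that translation, with the app keys, header name stand-in, and `cert_identity` source invented for illustration.

```python
# Hypothetical caller: the real avtomatika handlers are not shown in this diff.
# Assumes storage/config objects compatible with verify_worker_auth's signature above.
from aiohttp import web

from avtomatika.security import verify_worker_auth


async def heartbeat_handler(request: web.Request) -> web.Response:
    storage = request.app["storage"]  # placeholder app keys
    config = request.app["config"]
    try:
        worker_id = await verify_worker_auth(
            storage,
            config,
            token=request.headers.get("X-Worker-Token"),  # stand-in for AUTH_HEADER_WORKER
            cert_identity=request.get("cert_identity"),   # e.g. set by a TLS-terminating layer
            worker_id_hint=request.match_info.get("worker_id"),
        )
    except ValueError as e:
        return web.json_response({"error": str(e)}, status=400)
    except PermissionError as e:
        return web.json_response({"error": str(e)}, status=401)

    return web.json_response({"worker_id": worker_id, "status": "ok"})
```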