python3-commons 0.0.0__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python3-commons might be problematic. Click here for more details.
- python3_commons/api_client.py +44 -16
- python3_commons/audit.py +127 -138
- python3_commons/auth.py +53 -47
- python3_commons/cache.py +36 -38
- python3_commons/conf.py +37 -6
- python3_commons/db/__init__.py +15 -10
- python3_commons/db/helpers.py +5 -7
- python3_commons/db/models/__init__.py +8 -2
- python3_commons/db/models/auth.py +2 -2
- python3_commons/db/models/common.py +8 -6
- python3_commons/db/models/rbac.py +5 -5
- python3_commons/fs.py +2 -2
- python3_commons/helpers.py +44 -13
- python3_commons/object_storage.py +135 -73
- python3_commons/permissions.py +2 -4
- python3_commons/serializers/common.py +8 -0
- python3_commons/serializers/json.py +5 -7
- python3_commons/serializers/msgpack.py +19 -21
- python3_commons/serializers/msgspec.py +50 -27
- {python3_commons-0.0.0.dist-info → python3_commons-0.2.17.dist-info}/METADATA +13 -13
- python3_commons-0.2.17.dist-info/RECORD +30 -0
- {python3_commons-0.0.0.dist-info → python3_commons-0.2.17.dist-info}/WHEEL +1 -1
- python3_commons-0.0.0.dist-info/RECORD +0 -29
- /python3_commons/{logging → log}/__init__.py +0 -0
- /python3_commons/{logging → log}/filters.py +0 -0
- /python3_commons/{logging → log}/formatters.py +0 -0
- {python3_commons-0.0.0.dist-info → python3_commons-0.2.17.dist-info}/licenses/AUTHORS.rst +0 -0
- {python3_commons-0.0.0.dist-info → python3_commons-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {python3_commons-0.0.0.dist-info → python3_commons-0.2.17.dist-info}/top_level.txt +0 -0
python3_commons/api_client.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
|
1
|
+
import errno
|
|
1
2
|
import logging
|
|
3
|
+
from collections.abc import AsyncGenerator, Mapping, Sequence
|
|
2
4
|
from contextlib import asynccontextmanager
|
|
3
5
|
from datetime import UTC, datetime
|
|
4
6
|
from enum import Enum
|
|
5
7
|
from http import HTTPStatus
|
|
6
8
|
from json import dumps
|
|
7
|
-
from typing import
|
|
9
|
+
from typing import Literal
|
|
8
10
|
from uuid import uuid4
|
|
9
11
|
|
|
10
12
|
from aiohttp import ClientResponse, ClientSession, ClientTimeout, client_exceptions
|
|
11
|
-
from
|
|
13
|
+
from aiohttp.abc import URL
|
|
12
14
|
|
|
13
15
|
from python3_commons import audit
|
|
14
16
|
from python3_commons.conf import s3_settings
|
|
@@ -38,7 +40,7 @@ async def _store_response_for_audit(
|
|
|
38
40
|
@asynccontextmanager
|
|
39
41
|
async def request(
|
|
40
42
|
client: ClientSession,
|
|
41
|
-
base_url:
|
|
43
|
+
base_url: str,
|
|
42
44
|
uri: str,
|
|
43
45
|
query: Mapping | None = None,
|
|
44
46
|
method: Literal['get', 'post', 'put', 'patch', 'options', 'head', 'delete'] = 'get',
|
|
@@ -52,18 +54,21 @@ async def request(
|
|
|
52
54
|
date_path = now.strftime('%Y/%m/%d')
|
|
53
55
|
timestamp = now.strftime('%H%M%S_%f')
|
|
54
56
|
request_id = str(uuid4())[-12:]
|
|
55
|
-
uri_path = uri
|
|
56
|
-
uri_path = uri_path
|
|
57
|
+
uri_path = uri.removesuffix('/')
|
|
58
|
+
uri_path = uri_path.removeprefix('/')
|
|
57
59
|
url = f'{u[:-1] if (u := str(base_url)).endswith("/") else u}{uri}'
|
|
58
60
|
|
|
59
61
|
if audit_name:
|
|
60
62
|
curl_request = None
|
|
63
|
+
cookies = client.cookie_jar.filter_cookies(URL(base_url)) if base_url else None
|
|
61
64
|
|
|
62
65
|
if method == 'get':
|
|
63
66
|
if headers or query:
|
|
64
|
-
curl_request = request_to_curl(url, query, method, headers)
|
|
67
|
+
curl_request = request_to_curl(url=url, query=query, method=method, headers=headers, cookies=cookies)
|
|
65
68
|
else:
|
|
66
|
-
curl_request = request_to_curl(
|
|
69
|
+
curl_request = request_to_curl(
|
|
70
|
+
url=url, query=query, method=method, headers=headers, cookies=cookies, json=json, data=data
|
|
71
|
+
)
|
|
67
72
|
|
|
68
73
|
if curl_request:
|
|
69
74
|
await audit.write_audit_data(
|
|
@@ -71,6 +76,7 @@ async def request(
|
|
|
71
76
|
f'{date_path}/{audit_name}/{uri_path}/{method}_{timestamp}_{request_id}_request.txt',
|
|
72
77
|
curl_request.encode('utf-8'),
|
|
73
78
|
)
|
|
79
|
+
|
|
74
80
|
client_method = getattr(client, method)
|
|
75
81
|
|
|
76
82
|
logger.debug(f'Requesting {method} {url}')
|
|
@@ -86,13 +92,25 @@ async def request(
|
|
|
86
92
|
else:
|
|
87
93
|
match response.status:
|
|
88
94
|
case HTTPStatus.UNAUTHORIZED:
|
|
89
|
-
|
|
95
|
+
msg = 'Unauthorized'
|
|
96
|
+
|
|
97
|
+
raise PermissionError(msg)
|
|
90
98
|
case HTTPStatus.FORBIDDEN:
|
|
91
|
-
|
|
99
|
+
msg = 'Forbidden'
|
|
100
|
+
|
|
101
|
+
raise PermissionError(msg)
|
|
92
102
|
case HTTPStatus.NOT_FOUND:
|
|
93
|
-
|
|
103
|
+
msg = 'Not found'
|
|
104
|
+
|
|
105
|
+
raise LookupError(msg)
|
|
94
106
|
case HTTPStatus.BAD_REQUEST:
|
|
95
|
-
|
|
107
|
+
msg = 'Bad request'
|
|
108
|
+
|
|
109
|
+
raise ValueError(msg)
|
|
110
|
+
case HTTPStatus.TOO_MANY_REQUESTS:
|
|
111
|
+
msg = 'Too many requests'
|
|
112
|
+
|
|
113
|
+
raise InterruptedError(msg)
|
|
96
114
|
case _:
|
|
97
115
|
response.raise_for_status()
|
|
98
116
|
else:
|
|
@@ -109,12 +127,22 @@ async def request(
|
|
|
109
127
|
await _store_response_for_audit(response, audit_name, uri_path, method, request_id)
|
|
110
128
|
|
|
111
129
|
yield response
|
|
130
|
+
except client_exceptions.ClientConnectorError as e:
|
|
131
|
+
msg = 'Cient connection error'
|
|
132
|
+
|
|
133
|
+
raise ConnectionRefusedError(msg) from e
|
|
112
134
|
except client_exceptions.ClientOSError as e:
|
|
113
|
-
if e.errno ==
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
raise ConnectionResetError(
|
|
135
|
+
if e.errno == errno.EPIPE:
|
|
136
|
+
msg = 'Broken pipe'
|
|
137
|
+
|
|
138
|
+
raise ConnectionResetError(msg) from e
|
|
139
|
+
elif e.errno == errno.ECONNRESET:
|
|
140
|
+
msg = 'Connection reset by peer'
|
|
141
|
+
|
|
142
|
+
raise ConnectionResetError(msg) from e
|
|
117
143
|
|
|
118
144
|
raise
|
|
119
145
|
except client_exceptions.ServerDisconnectedError as e:
|
|
120
|
-
|
|
146
|
+
msg = 'Server disconnected'
|
|
147
|
+
|
|
148
|
+
raise ConnectionResetError(msg) from e
|
python3_commons/audit.py
CHANGED
|
@@ -1,167 +1,156 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import io
|
|
3
3
|
import logging
|
|
4
|
-
import
|
|
5
|
-
from bz2 import BZ2Compressor
|
|
6
|
-
from collections import deque
|
|
7
|
-
from datetime import UTC, datetime, timedelta
|
|
8
|
-
from typing import Generator, Iterable
|
|
4
|
+
from datetime import UTC, datetime
|
|
9
5
|
from uuid import uuid4
|
|
10
6
|
|
|
11
7
|
from lxml import etree
|
|
12
|
-
from minio import S3Error
|
|
13
8
|
from zeep.plugins import Plugin
|
|
14
9
|
from zeep.wsdl.definitions import AbstractOperation
|
|
15
10
|
|
|
16
11
|
from python3_commons import object_storage
|
|
17
12
|
from python3_commons.conf import S3Settings, s3_settings
|
|
18
|
-
from python3_commons.object_storage import ObjectStorage
|
|
19
13
|
|
|
20
14
|
logger = logging.getLogger(__name__)
|
|
21
15
|
|
|
22
16
|
|
|
23
|
-
class GeneratedStream(io.BytesIO):
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def generate_archive(
|
|
71
|
-
|
|
72
|
-
) ->
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
17
|
+
# class GeneratedStream(io.BytesIO):
|
|
18
|
+
# def __init__(self, generator: AsyncGenerator[bytes], *args, **kwargs):
|
|
19
|
+
# super().__init__(*args, **kwargs)
|
|
20
|
+
# self.generator = generator
|
|
21
|
+
#
|
|
22
|
+
# def read(self, size: int = -1):
|
|
23
|
+
# if size < 0:
|
|
24
|
+
# while True:
|
|
25
|
+
# try:
|
|
26
|
+
# chunk = anext(self.generator)
|
|
27
|
+
# except StopIteration:
|
|
28
|
+
# break
|
|
29
|
+
# else:
|
|
30
|
+
# self.write(chunk)
|
|
31
|
+
# else:
|
|
32
|
+
# total_written_size = 0
|
|
33
|
+
#
|
|
34
|
+
# while total_written_size < size:
|
|
35
|
+
# try:
|
|
36
|
+
# chunk = anext(self.generator)
|
|
37
|
+
# except StopIteration:
|
|
38
|
+
# break
|
|
39
|
+
# else:
|
|
40
|
+
# total_written_size += self.write(chunk)
|
|
41
|
+
#
|
|
42
|
+
# self.seek(0)
|
|
43
|
+
#
|
|
44
|
+
# if chunk := super().read(size):
|
|
45
|
+
# pos = self.tell()
|
|
46
|
+
#
|
|
47
|
+
# buf = self.getbuffer()
|
|
48
|
+
# unread_data_size = len(buf) - pos
|
|
49
|
+
#
|
|
50
|
+
# if unread_data_size > 0:
|
|
51
|
+
# buf[:unread_data_size] = buf[pos : pos + unread_data_size]
|
|
52
|
+
#
|
|
53
|
+
# del buf
|
|
54
|
+
#
|
|
55
|
+
# self.seek(0)
|
|
56
|
+
# self.truncate(unread_data_size)
|
|
57
|
+
#
|
|
58
|
+
# return chunk
|
|
59
|
+
#
|
|
60
|
+
# def readable(self):
|
|
61
|
+
# return True
|
|
62
|
+
#
|
|
63
|
+
#
|
|
64
|
+
# async def generate_archive(
|
|
65
|
+
# objects: AsyncGenerator[tuple[str, datetime, bytes]], chunk_size: int = 4096
|
|
66
|
+
# ) -> AsyncGenerator[bytes]:
|
|
67
|
+
# buffer = deque()
|
|
68
|
+
#
|
|
69
|
+
# with tarfile.open(fileobj=buffer, mode='w') as archive:
|
|
70
|
+
# async for name, last_modified, content in objects:
|
|
71
|
+
# logger.info(f'Adding {name} to archive')
|
|
72
|
+
# info = tarfile.TarInfo(name)
|
|
73
|
+
# info.size = len(content)
|
|
74
|
+
# info.mtime = int(last_modified.timestamp())
|
|
75
|
+
# archive.addfile(info, io.BytesIO(content))
|
|
76
|
+
#
|
|
77
|
+
# buffer_length = buffer.tell()
|
|
78
|
+
#
|
|
79
|
+
# while buffer_length >= chunk_size:
|
|
80
|
+
# buffer.seek(0)
|
|
81
|
+
# chunk = buffer.read(chunk_size)
|
|
82
|
+
# chunk_len = len(chunk)
|
|
83
|
+
#
|
|
84
|
+
# if not chunk:
|
|
85
|
+
# break
|
|
86
|
+
#
|
|
87
|
+
# yield chunk
|
|
88
|
+
#
|
|
89
|
+
# buffer.seek(0)
|
|
90
|
+
# buffer.truncate(chunk_len)
|
|
91
|
+
# buffer.seek(0, io.SEEK_END)
|
|
92
|
+
# buffer_length = buffer.tell()
|
|
93
|
+
#
|
|
94
|
+
# while True:
|
|
95
|
+
# chunk = buffer.read(chunk_size)
|
|
96
|
+
#
|
|
97
|
+
# if not chunk:
|
|
98
|
+
# break
|
|
99
|
+
#
|
|
100
|
+
# yield chunk
|
|
101
|
+
#
|
|
102
|
+
# buffer.seek(0)
|
|
103
|
+
# buffer.truncate(0)
|
|
104
|
+
#
|
|
105
|
+
#
|
|
106
|
+
# async def generate_bzip2(chunks: AsyncGenerator[bytes]) -> AsyncGenerator[bytes]:
|
|
107
|
+
# compressor = BZ2Compressor()
|
|
108
|
+
#
|
|
109
|
+
# async for chunk in chunks:
|
|
110
|
+
# if compressed_chunk := compressor.compress(chunk):
|
|
111
|
+
# yield compressed_chunk
|
|
112
|
+
#
|
|
113
|
+
# if compressed_chunk := compressor.flush():
|
|
114
|
+
# yield compressed_chunk
|
|
115
|
+
#
|
|
116
|
+
|
|
117
|
+
# async def archive_audit_data(root_path: str = 'audit'):
|
|
118
|
+
# now = datetime.now(tz=UTC) - timedelta(days=1)
|
|
119
|
+
# year = now.year
|
|
120
|
+
# month = now.month
|
|
121
|
+
# day = now.day
|
|
122
|
+
# bucket_name = s3_settings.s3_bucket
|
|
123
|
+
# date_path = object_storage.get_absolute_path(f'{root_path}/{year}/{month:02}/{day:02}')
|
|
124
|
+
#
|
|
125
|
+
# if objects := object_storage.get_objects(bucket_name, date_path, recursive=True):
|
|
126
|
+
# logger.info(f'Compacting files in: {date_path}')
|
|
127
|
+
#
|
|
128
|
+
# generator = generate_archive(objects, chunk_size=900_000)
|
|
129
|
+
# bzip2_generator = generate_bzip2(generator)
|
|
130
|
+
# archive_stream = GeneratedStream(bzip2_generator)
|
|
131
|
+
#
|
|
132
|
+
# archive_path = object_storage.get_absolute_path(f'audit/.archive/{year}_{month:02}_{day:02}.tar.bz2')
|
|
133
|
+
# await object_storage.put_object(bucket_name, archive_path, archive_stream, -1, part_size=5 * 1024 * 1024)
|
|
134
|
+
#
|
|
135
|
+
# if errors := await object_storage.remove_objects(bucket_name, date_path):
|
|
136
|
+
# for error in errors:
|
|
137
|
+
# logger.error(f'Failed to delete object in {bucket_name=}: {error}')
|
|
99
138
|
|
|
100
|
-
while True:
|
|
101
|
-
chunk = buffer.read(chunk_size)
|
|
102
139
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
yield chunk
|
|
107
|
-
|
|
108
|
-
buffer.seek(0)
|
|
109
|
-
buffer.truncate(0)
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
def generate_bzip2(chunks: Generator[bytes, None, None]) -> Generator[bytes, None, None]:
|
|
113
|
-
compressor = BZ2Compressor()
|
|
114
|
-
|
|
115
|
-
for chunk in chunks:
|
|
116
|
-
if compressed_chunk := compressor.compress(chunk):
|
|
117
|
-
yield compressed_chunk
|
|
118
|
-
|
|
119
|
-
if compressed_chunk := compressor.flush():
|
|
120
|
-
yield compressed_chunk
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
def write_audit_data_sync(settings: S3Settings, key: str, data: bytes):
|
|
124
|
-
if settings.s3_secret_access_key:
|
|
140
|
+
async def write_audit_data(settings: S3Settings, key: str, data: bytes):
|
|
141
|
+
if settings.aws_secret_access_key:
|
|
125
142
|
try:
|
|
126
|
-
client = ObjectStorage(settings).get_client()
|
|
127
143
|
absolute_path = object_storage.get_absolute_path(f'audit/{key}')
|
|
128
144
|
|
|
129
|
-
|
|
130
|
-
except
|
|
131
|
-
logger.
|
|
145
|
+
await object_storage.put_object(settings.s3_bucket, absolute_path, io.BytesIO(data), len(data))
|
|
146
|
+
except Exception:
|
|
147
|
+
logger.exception('Failed storing object in storage.')
|
|
132
148
|
else:
|
|
133
149
|
logger.debug(f'Stored object in storage: {key}')
|
|
134
150
|
else:
|
|
135
151
|
logger.debug(f'S3 is not configured, not storing object in storage: {key}')
|
|
136
152
|
|
|
137
153
|
|
|
138
|
-
async def write_audit_data(settings: S3Settings, key: str, data: bytes):
|
|
139
|
-
await asyncio.to_thread(write_audit_data_sync, settings, key, data)
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
async def archive_audit_data(root_path: str = 'audit'):
|
|
143
|
-
now = datetime.now(tz=UTC) - timedelta(days=1)
|
|
144
|
-
year = now.year
|
|
145
|
-
month = now.month
|
|
146
|
-
day = now.day
|
|
147
|
-
bucket_name = s3_settings.s3_bucket
|
|
148
|
-
date_path = object_storage.get_absolute_path(f'{root_path}/{year}/{month:02}/{day:02}')
|
|
149
|
-
|
|
150
|
-
if objects := object_storage.get_objects(bucket_name, date_path, recursive=True):
|
|
151
|
-
logger.info(f'Compacting files in: {date_path}')
|
|
152
|
-
|
|
153
|
-
generator = generate_archive(objects, chunk_size=900_000)
|
|
154
|
-
bzip2_generator = generate_bzip2(generator)
|
|
155
|
-
archive_stream = GeneratedStream(bzip2_generator)
|
|
156
|
-
|
|
157
|
-
archive_path = object_storage.get_absolute_path(f'audit/.archive/{year}_{month:02}_{day:02}.tar.bz2')
|
|
158
|
-
object_storage.put_object(bucket_name, archive_path, archive_stream, -1, part_size=5 * 1024 * 1024)
|
|
159
|
-
|
|
160
|
-
if errors := object_storage.remove_objects(bucket_name, date_path):
|
|
161
|
-
for error in errors:
|
|
162
|
-
logger.error(f'Failed to delete object in {bucket_name=}: {error}')
|
|
163
|
-
|
|
164
|
-
|
|
165
154
|
class ZeepAuditPlugin(Plugin):
|
|
166
155
|
def __init__(self, audit_name: str = 'zeep'):
|
|
167
156
|
super().__init__()
|
python3_commons/auth.py
CHANGED
|
@@ -1,28 +1,28 @@
|
|
|
1
1
|
import logging
|
|
2
|
+
from collections.abc import Callable, Coroutine, MutableMapping, Sequence
|
|
2
3
|
from http import HTTPStatus
|
|
3
|
-
from typing import Annotated
|
|
4
|
+
from typing import Annotated, Any, TypeVar
|
|
4
5
|
|
|
5
6
|
import aiohttp
|
|
7
|
+
import msgspec
|
|
6
8
|
from fastapi import Depends, HTTPException
|
|
7
9
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
8
10
|
from jose import JWTError, jwt
|
|
9
|
-
from pydantic import BaseModel
|
|
10
11
|
|
|
11
12
|
from python3_commons.conf import oidc_settings
|
|
12
13
|
|
|
13
14
|
logger = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
class TokenData(
|
|
17
|
+
class TokenData(msgspec.Struct):
|
|
17
18
|
sub: str
|
|
18
|
-
aud: str
|
|
19
|
+
aud: str | Sequence[str]
|
|
19
20
|
exp: int
|
|
20
21
|
iss: str
|
|
21
22
|
|
|
22
23
|
|
|
24
|
+
T = TypeVar('T', bound=TokenData)
|
|
23
25
|
OIDC_CONFIG_URL = f'{oidc_settings.authority_url}/.well-known/openid-configuration'
|
|
24
|
-
_JWKS: dict | None = None
|
|
25
|
-
|
|
26
26
|
bearer_security = HTTPBearer(auto_error=oidc_settings.enabled)
|
|
27
27
|
|
|
28
28
|
|
|
@@ -30,53 +30,59 @@ async def fetch_openid_config() -> dict:
|
|
|
30
30
|
"""
|
|
31
31
|
Fetch the OpenID configuration (including JWKS URI) from OIDC authority.
|
|
32
32
|
"""
|
|
33
|
-
async with aiohttp.ClientSession() as session:
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
)
|
|
33
|
+
async with aiohttp.ClientSession() as session, session.get(OIDC_CONFIG_URL) as response:
|
|
34
|
+
if response.status != HTTPStatus.OK:
|
|
35
|
+
raise HTTPException(
|
|
36
|
+
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch OpenID configuration'
|
|
37
|
+
)
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
return await response.json()
|
|
41
40
|
|
|
42
41
|
|
|
43
42
|
async def fetch_jwks(jwks_uri: str) -> dict:
|
|
44
43
|
"""
|
|
45
44
|
Fetch the JSON Web Key Set (JWKS) for validating the token's signature.
|
|
46
45
|
"""
|
|
47
|
-
async with aiohttp.ClientSession() as session:
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
) ->
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
46
|
+
async with aiohttp.ClientSession() as session, session.get(jwks_uri) as response:
|
|
47
|
+
if response.status != HTTPStatus.OK:
|
|
48
|
+
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch JWKS')
|
|
49
|
+
|
|
50
|
+
return await response.json()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_token_verifier[T](
|
|
54
|
+
token_cls: type[T],
|
|
55
|
+
jwks: MutableMapping,
|
|
56
|
+
) -> Callable[[HTTPAuthorizationCredentials], Coroutine[Any, Any, T | None]]:
|
|
57
|
+
async def get_verified_token(
|
|
58
|
+
authorization: Annotated[HTTPAuthorizationCredentials, Depends(bearer_security)],
|
|
59
|
+
) -> T | None:
|
|
60
|
+
"""
|
|
61
|
+
Verify the JWT access token using OIDC authority JWKS.
|
|
62
|
+
"""
|
|
63
|
+
if not oidc_settings.enabled:
|
|
64
|
+
return None
|
|
65
|
+
|
|
66
|
+
token = authorization.credentials
|
|
67
|
+
|
|
68
|
+
try:
|
|
69
|
+
if not jwks:
|
|
70
|
+
openid_config = await fetch_openid_config()
|
|
71
|
+
_jwks = await fetch_jwks(openid_config['jwks_uri'])
|
|
72
|
+
jwks.clear()
|
|
73
|
+
jwks.update(_jwks)
|
|
74
|
+
|
|
75
|
+
if oidc_settings.client_id:
|
|
76
|
+
payload = jwt.decode(token, jwks, algorithms=['RS256'], audience=oidc_settings.client_id)
|
|
77
|
+
else:
|
|
78
|
+
payload = jwt.decode(token, jwks, algorithms=['RS256'])
|
|
79
|
+
|
|
80
|
+
token_data = token_cls(**payload)
|
|
81
|
+
except jwt.ExpiredSignatureError as e:
|
|
82
|
+
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail='Token has expired') from e
|
|
83
|
+
except JWTError as e:
|
|
84
|
+
raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=f'Token is invalid: {e!s}') from e
|
|
79
85
|
|
|
80
86
|
return token_data
|
|
81
|
-
|
|
82
|
-
|
|
87
|
+
|
|
88
|
+
return get_verified_token
|