python3-commons 0.0.0__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of python3-commons might be problematic.

@@ -1,14 +1,16 @@
+import errno
 import logging
+from collections.abc import AsyncGenerator, Mapping, Sequence
 from contextlib import asynccontextmanager
 from datetime import UTC, datetime
 from enum import Enum
 from http import HTTPStatus
 from json import dumps
-from typing import AsyncGenerator, Literal, Mapping, Sequence
+from typing import Literal
 from uuid import uuid4
 
 from aiohttp import ClientResponse, ClientSession, ClientTimeout, client_exceptions
-from pydantic import HttpUrl
+from aiohttp.abc import URL
 
 from python3_commons import audit
 from python3_commons.conf import s3_settings
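
Note on the import moves above: since Python 3.9 the typing versions of AsyncGenerator, Mapping, and Sequence are deprecated aliases, and the canonical classes live in collections.abc. A minimal illustration:

    # Deprecated aliases (pre-0.2.17 style):
    # from typing import AsyncGenerator, Mapping, Sequence

    # Canonical imports, as the new code does; Literal still lives in typing:
    from collections.abc import AsyncGenerator, Mapping, Sequence
    from typing import Literal
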
@@ -38,7 +40,7 @@ async def _store_response_for_audit(
 @asynccontextmanager
 async def request(
     client: ClientSession,
-    base_url: HttpUrl,
+    base_url: str,
     uri: str,
     query: Mapping | None = None,
     method: Literal['get', 'post', 'put', 'patch', 'options', 'head', 'delete'] = 'get',
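
With base_url loosened from pydantic's HttpUrl to a plain str, callers no longer need pydantic to build the argument. A minimal call sketch, assuming only the request() signature shown in this diff (the host and path are illustrative):

    from aiohttp import ClientSession

    async def fetch_items() -> dict:
        async with ClientSession() as client:
            async with request(client, 'https://api.example.com', '/v1/items') as response:
                return await response.json()
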
@@ -52,18 +54,21 @@ async def request(
     date_path = now.strftime('%Y/%m/%d')
     timestamp = now.strftime('%H%M%S_%f')
     request_id = str(uuid4())[-12:]
-    uri_path = uri[:-1] if uri.endswith('/') else uri
-    uri_path = uri_path[1:] if uri_path.startswith('/') else uri_path
+    uri_path = uri.removesuffix('/')
+    uri_path = uri_path.removeprefix('/')
     url = f'{u[:-1] if (u := str(base_url)).endswith("/") else u}{uri}'
 
     if audit_name:
         curl_request = None
+        cookies = client.cookie_jar.filter_cookies(URL(base_url)) if base_url else None
 
         if method == 'get':
             if headers or query:
-                curl_request = request_to_curl(url, query, method, headers)
+                curl_request = request_to_curl(url=url, query=query, method=method, headers=headers, cookies=cookies)
         else:
-            curl_request = request_to_curl(url, query, method, headers, json, data)
+            curl_request = request_to_curl(
+                url=url, query=query, method=method, headers=headers, cookies=cookies, json=json, data=data
+            )
 
         if curl_request:
             await audit.write_audit_data(
@@ -71,6 +76,7 @@ async def request(
                 f'{date_path}/{audit_name}/{uri_path}/{method}_{timestamp}_{request_id}_request.txt',
                 curl_request.encode('utf-8'),
             )
+
     client_method = getattr(client, method)
 
     logger.debug(f'Requesting {method} {url}')
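
The str.removesuffix()/str.removeprefix() rewrite above (available since Python 3.9) is behaviorally equivalent to the old conditional slicing: each call removes at most one occurrence of the affix and is a no-op when it is absent. For example:

    >>> '/v1/items/'.removesuffix('/').removeprefix('/')
    'v1/items'
    >>> 'v1/items'.removesuffix('/')  # no-op: the suffix is absent
    'v1/items'

The new cookies line uses AbstractCookieJar.filter_cookies() to pull the cookies the session would send for base_url, so the generated curl command reproduces the request more faithfully.
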
@@ -86,13 +92,25 @@ async def request(
             else:
                 match response.status:
                     case HTTPStatus.UNAUTHORIZED:
-                        raise PermissionError('Unauthorized')
+                        msg = 'Unauthorized'
+
+                        raise PermissionError(msg)
                     case HTTPStatus.FORBIDDEN:
-                        raise PermissionError('Forbidden')
+                        msg = 'Forbidden'
+
+                        raise PermissionError(msg)
                     case HTTPStatus.NOT_FOUND:
-                        raise LookupError('Not found')
+                        msg = 'Not found'
+
+                        raise LookupError(msg)
                     case HTTPStatus.BAD_REQUEST:
-                        raise ValueError('Bad request')
+                        msg = 'Bad request'
+
+                        raise ValueError(msg)
+                    case HTTPStatus.TOO_MANY_REQUESTS:
+                        msg = 'Too many requests'
+
+                        raise InterruptedError(msg)
                     case _:
                         response.raise_for_status()
         else:
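
The match block maps client-error statuses onto built-in exception types, so call sites can catch exceptions instead of inspecting status codes; note that the new 429 case reuses InterruptedError rather than a dedicated rate-limit exception. A call-site sketch under that mapping (the path is illustrative):

    try:
        async with request(client, base_url, '/v1/items') as response:
            payload = await response.json()
    except PermissionError:   # 401 Unauthorized or 403 Forbidden
        ...
    except LookupError:       # 404 Not Found
        ...
    except ValueError:        # 400 Bad Request
        ...
    except InterruptedError:  # 429 Too Many Requests, per the mapping above
        ...
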
@@ -109,12 +127,22 @@ async def request(
             await _store_response_for_audit(response, audit_name, uri_path, method, request_id)
 
             yield response
+    except client_exceptions.ClientConnectorError as e:
+        msg = 'Cient connection error'
+
+        raise ConnectionRefusedError(msg) from e
     except client_exceptions.ClientOSError as e:
-        if e.errno == 32:
-            raise ConnectionResetError('Broken pipe') from e
-        elif e.errno == 104:
-            raise ConnectionResetError('Connection reset by peer') from e
+        if e.errno == errno.EPIPE:
+            msg = 'Broken pipe'
+
+            raise ConnectionResetError(msg) from e
+        elif e.errno == errno.ECONNRESET:
+            msg = 'Connection reset by peer'
+
+            raise ConnectionResetError(msg) from e
 
         raise
     except client_exceptions.ServerDisconnectedError as e:
-        raise ConnectionResetError('Server disconnected') from e
+        msg = 'Server disconnected'
+
+        raise ConnectionResetError(msg) from e
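
Replacing the bare literals 32 and 104 with errno.EPIPE and errno.ECONNRESET is a portability fix: the numeric values happen to hold on Linux but are platform-defined elsewhere.

    import errno

    # True on Linux; the named constants are correct on every platform:
    assert errno.EPIPE == 32
    assert errno.ECONNRESET == 104
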
python3_commons/audit.py CHANGED
@@ -1,167 +1,156 @@
 import asyncio
 import io
 import logging
-import tarfile
-from bz2 import BZ2Compressor
-from collections import deque
-from datetime import UTC, datetime, timedelta
-from typing import Generator, Iterable
+from datetime import UTC, datetime
 from uuid import uuid4
 
 from lxml import etree
-from minio import S3Error
 from zeep.plugins import Plugin
 from zeep.wsdl.definitions import AbstractOperation
 
 from python3_commons import object_storage
 from python3_commons.conf import S3Settings, s3_settings
-from python3_commons.object_storage import ObjectStorage
 
 logger = logging.getLogger(__name__)
 
 
-class GeneratedStream(io.BytesIO):
-    def __init__(self, generator: Generator[bytes, None, None], *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.generator = generator
-
-    def read(self, size: int = -1):
-        if size < 0:
-            while True:
-                try:
-                    chunk = next(self.generator)
-                except StopIteration:
-                    break
-                else:
-                    self.write(chunk)
-        else:
-            total_written_size = 0
-
-            while total_written_size < size:
-                try:
-                    chunk = next(self.generator)
-                except StopIteration:
-                    break
-                else:
-                    total_written_size += self.write(chunk)
-
-        self.seek(0)
-
-        if chunk := super().read(size):
-            pos = self.tell()
-
-            buf = self.getbuffer()
-            unread_data_size = len(buf) - pos
-
-            if unread_data_size > 0:
-                buf[:unread_data_size] = buf[pos : pos + unread_data_size]
-
-            del buf
-
-            self.seek(0)
-            self.truncate(unread_data_size)
-
-        return chunk
-
-    def readable(self):
-        return True
-
-
-def generate_archive(
-    objects: Iterable[tuple[str, datetime, bytes]], chunk_size: int = 4096
-) -> Generator[bytes, None, None]:
-    buffer = deque()
-
-    with tarfile.open(fileobj=buffer, mode='w') as archive:
-        for name, last_modified, content in objects:
-            logger.info(f'Adding {name} to archive')
-            info = tarfile.TarInfo(name)
-            info.size = len(content)
-            info.mtime = last_modified.timestamp()
-            archive.addfile(info, io.BytesIO(content))
-
-            buffer_length = buffer.tell()
-
-            while buffer_length >= chunk_size:
-                buffer.seek(0)
-                chunk = buffer.read(chunk_size)
-                chunk_len = len(chunk)
-
-                if not chunk:
-                    break
-
-                yield chunk
-
-                buffer.seek(0)
-                buffer.truncate(chunk_len)
-                buffer.seek(0, io.SEEK_END)
-                buffer_length = buffer.tell()
+# class GeneratedStream(io.BytesIO):
+#     def __init__(self, generator: AsyncGenerator[bytes], *args, **kwargs):
+#         super().__init__(*args, **kwargs)
+#         self.generator = generator
+#
+#     def read(self, size: int = -1):
+#         if size < 0:
+#             while True:
+#                 try:
+#                     chunk = anext(self.generator)
+#                 except StopIteration:
+#                     break
+#                 else:
+#                     self.write(chunk)
+#         else:
+#             total_written_size = 0
+#
+#             while total_written_size < size:
+#                 try:
+#                     chunk = anext(self.generator)
+#                 except StopIteration:
+#                     break
+#                 else:
+#                     total_written_size += self.write(chunk)
+#
+#         self.seek(0)
+#
+#         if chunk := super().read(size):
+#             pos = self.tell()
+#
+#             buf = self.getbuffer()
+#             unread_data_size = len(buf) - pos
+#
+#             if unread_data_size > 0:
+#                 buf[:unread_data_size] = buf[pos : pos + unread_data_size]
+#
+#             del buf
+#
+#             self.seek(0)
+#             self.truncate(unread_data_size)
+#
+#         return chunk
+#
+#     def readable(self):
+#         return True
+#
+#
+# async def generate_archive(
+#     objects: AsyncGenerator[tuple[str, datetime, bytes]], chunk_size: int = 4096
+# ) -> AsyncGenerator[bytes]:
+#     buffer = deque()
+#
+#     with tarfile.open(fileobj=buffer, mode='w') as archive:
+#         async for name, last_modified, content in objects:
+#             logger.info(f'Adding {name} to archive')
+#             info = tarfile.TarInfo(name)
+#             info.size = len(content)
+#             info.mtime = int(last_modified.timestamp())
+#             archive.addfile(info, io.BytesIO(content))
+#
+#             buffer_length = buffer.tell()
+#
+#             while buffer_length >= chunk_size:
+#                 buffer.seek(0)
+#                 chunk = buffer.read(chunk_size)
+#                 chunk_len = len(chunk)
+#
+#                 if not chunk:
+#                     break
+#
+#                 yield chunk
+#
+#                 buffer.seek(0)
+#                 buffer.truncate(chunk_len)
+#                 buffer.seek(0, io.SEEK_END)
+#                 buffer_length = buffer.tell()
+#
+#     while True:
+#         chunk = buffer.read(chunk_size)
+#
+#         if not chunk:
+#             break
+#
+#         yield chunk
+#
+#         buffer.seek(0)
+#         buffer.truncate(0)
+#
+#
+# async def generate_bzip2(chunks: AsyncGenerator[bytes]) -> AsyncGenerator[bytes]:
+#     compressor = BZ2Compressor()
+#
+#     async for chunk in chunks:
+#         if compressed_chunk := compressor.compress(chunk):
+#             yield compressed_chunk
+#
+#     if compressed_chunk := compressor.flush():
+#         yield compressed_chunk
+#
+
+# async def archive_audit_data(root_path: str = 'audit'):
+#     now = datetime.now(tz=UTC) - timedelta(days=1)
+#     year = now.year
+#     month = now.month
+#     day = now.day
+#     bucket_name = s3_settings.s3_bucket
+#     date_path = object_storage.get_absolute_path(f'{root_path}/{year}/{month:02}/{day:02}')
+#
+#     if objects := object_storage.get_objects(bucket_name, date_path, recursive=True):
+#         logger.info(f'Compacting files in: {date_path}')
+#
+#         generator = generate_archive(objects, chunk_size=900_000)
+#         bzip2_generator = generate_bzip2(generator)
+#         archive_stream = GeneratedStream(bzip2_generator)
+#
+#         archive_path = object_storage.get_absolute_path(f'audit/.archive/{year}_{month:02}_{day:02}.tar.bz2')
+#         await object_storage.put_object(bucket_name, archive_path, archive_stream, -1, part_size=5 * 1024 * 1024)
+#
+#         if errors := await object_storage.remove_objects(bucket_name, date_path):
+#             for error in errors:
+#                 logger.error(f'Failed to delete object in {bucket_name=}: {error}')
 
-    while True:
-        chunk = buffer.read(chunk_size)
 
-        if not chunk:
-            break
-
-        yield chunk
-
-        buffer.seek(0)
-        buffer.truncate(0)
-
-
-def generate_bzip2(chunks: Generator[bytes, None, None]) -> Generator[bytes, None, None]:
-    compressor = BZ2Compressor()
-
-    for chunk in chunks:
-        if compressed_chunk := compressor.compress(chunk):
-            yield compressed_chunk
-
-    if compressed_chunk := compressor.flush():
-        yield compressed_chunk
-
-
-def write_audit_data_sync(settings: S3Settings, key: str, data: bytes):
-    if settings.s3_secret_access_key:
+async def write_audit_data(settings: S3Settings, key: str, data: bytes):
+    if settings.aws_secret_access_key:
         try:
-            client = ObjectStorage(settings).get_client()
             absolute_path = object_storage.get_absolute_path(f'audit/{key}')
 
-            client.put_object(settings.s3_bucket, absolute_path, io.BytesIO(data), len(data))
-        except S3Error as e:
-            logger.error(f'Failed storing object in storage: {e}')
+            await object_storage.put_object(settings.s3_bucket, absolute_path, io.BytesIO(data), len(data))
+        except Exception:
+            logger.exception('Failed storing object in storage.')
         else:
             logger.debug(f'Stored object in storage: {key}')
     else:
         logger.debug(f'S3 is not configured, not storing object in storage: {key}')
 
 
-async def write_audit_data(settings: S3Settings, key: str, data: bytes):
-    await asyncio.to_thread(write_audit_data_sync, settings, key, data)
-
-
-async def archive_audit_data(root_path: str = 'audit'):
-    now = datetime.now(tz=UTC) - timedelta(days=1)
-    year = now.year
-    month = now.month
-    day = now.day
-    bucket_name = s3_settings.s3_bucket
-    date_path = object_storage.get_absolute_path(f'{root_path}/{year}/{month:02}/{day:02}')
-
-    if objects := object_storage.get_objects(bucket_name, date_path, recursive=True):
-        logger.info(f'Compacting files in: {date_path}')
-
-        generator = generate_archive(objects, chunk_size=900_000)
-        bzip2_generator = generate_bzip2(generator)
-        archive_stream = GeneratedStream(bzip2_generator)
-
-        archive_path = object_storage.get_absolute_path(f'audit/.archive/{year}_{month:02}_{day:02}.tar.bz2')
-        object_storage.put_object(bucket_name, archive_path, archive_stream, -1, part_size=5 * 1024 * 1024)
-
-        if errors := object_storage.remove_objects(bucket_name, date_path):
-            for error in errors:
-                logger.error(f'Failed to delete object in {bucket_name=}: {error}')
-
-
 class ZeepAuditPlugin(Plugin):
     def __init__(self, audit_name: str = 'zeep'):
         super().__init__()
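
write_audit_data() is now natively async (the asyncio.to_thread() wrapper around a sync body is gone) and is gated on the renamed aws_secret_access_key setting. A minimal call sketch, assuming a configured s3_settings; the key below is illustrative and follows the date-path layout request() builds above:

    from python3_commons import audit
    from python3_commons.conf import s3_settings

    async def record_request(payload: bytes) -> None:
        # Stored under 'audit/<key>' in the configured bucket; skipped
        # (with a debug log) when S3 credentials are not configured.
        await audit.write_audit_data(
            s3_settings, '2025/01/01/myapp/v1/items/get_120000_000000_abc123_request.txt', payload
        )
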
python3_commons/auth.py CHANGED
@@ -1,28 +1,28 @@
 import logging
+from collections.abc import Callable, Coroutine, MutableMapping, Sequence
 from http import HTTPStatus
-from typing import Annotated
+from typing import Annotated, Any, TypeVar
 
 import aiohttp
+import msgspec
 from fastapi import Depends, HTTPException
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
 from jose import JWTError, jwt
-from pydantic import BaseModel
 
 from python3_commons.conf import oidc_settings
 
 logger = logging.getLogger(__name__)
 
 
-class TokenData(BaseModel):
+class TokenData(msgspec.Struct):
     sub: str
-    aud: str
+    aud: str | Sequence[str]
     exp: int
     iss: str
 
 
+T = TypeVar('T', bound=TokenData)
 OIDC_CONFIG_URL = f'{oidc_settings.authority_url}/.well-known/openid-configuration'
-_JWKS: dict | None = None
-
 bearer_security = HTTPBearer(auto_error=oidc_settings.enabled)
 
 
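
TokenData is now a msgspec.Struct instead of a pydantic BaseModel, and aud accepts either a single audience string or a sequence of audiences, matching how OIDC providers emit multi-audience tokens. Construction stays keyword-based; the claim values here are illustrative:

    token = TokenData(
        sub='user-123',
        aud=['account', 'api'],  # a plain 'account' string is equally valid now
        exp=1_700_000_000,
        iss='https://idp.example.com/realms/main',
    )
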
@@ -30,53 +30,59 @@ async def fetch_openid_config() -> dict:
     """
     Fetch the OpenID configuration (including JWKS URI) from OIDC authority.
     """
-    async with aiohttp.ClientSession() as session:
-        async with session.get(OIDC_CONFIG_URL) as response:
-            if response.status != HTTPStatus.OK:
-                raise HTTPException(
-                    status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch OpenID configuration'
-                )
+    async with aiohttp.ClientSession() as session, session.get(OIDC_CONFIG_URL) as response:
+        if response.status != HTTPStatus.OK:
+            raise HTTPException(
+                status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch OpenID configuration'
+            )
 
-            return await response.json()
+        return await response.json()
 
 
 async def fetch_jwks(jwks_uri: str) -> dict:
     """
     Fetch the JSON Web Key Set (JWKS) for validating the token's signature.
     """
-    async with aiohttp.ClientSession() as session:
-        async with session.get(jwks_uri) as response:
-            if response.status != HTTPStatus.OK:
-                raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch JWKS')
-
-            return await response.json()
-
-
-async def get_verified_token(
-    authorization: Annotated[HTTPAuthorizationCredentials, Depends(bearer_security)],
-) -> TokenData | None:
-    """
-    Verify the JWT access token using OIDC authority JWKS.
-    """
-    global _JWKS
-
-    if not oidc_settings.enabled:
-        return None
-
-    token = authorization.credentials
-
-    try:
-        if not _JWKS:
-            openid_config = await fetch_openid_config()
-            _JWKS = await fetch_jwks(openid_config['jwks_uri'])
-
-        if oidc_settings.client_id:
-            payload = jwt.decode(token, _JWKS, algorithms=['RS256'], audience=oidc_settings.client_id)
-        else:
-            payload = jwt.decode(token, _JWKS, algorithms=['RS256'])
-
-        token_data = TokenData(**payload)
+    async with aiohttp.ClientSession() as session, session.get(jwks_uri) as response:
+        if response.status != HTTPStatus.OK:
+            raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch JWKS')
+
+        return await response.json()
+
+
+def get_token_verifier[T](
+    token_cls: type[T],
+    jwks: MutableMapping,
+) -> Callable[[HTTPAuthorizationCredentials], Coroutine[Any, Any, T | None]]:
+    async def get_verified_token(
+        authorization: Annotated[HTTPAuthorizationCredentials, Depends(bearer_security)],
+    ) -> T | None:
+        """
+        Verify the JWT access token using OIDC authority JWKS.
+        """
+        if not oidc_settings.enabled:
+            return None
+
+        token = authorization.credentials
+
+        try:
+            if not jwks:
+                openid_config = await fetch_openid_config()
+                _jwks = await fetch_jwks(openid_config['jwks_uri'])
+                jwks.clear()
+                jwks.update(_jwks)
+
+            if oidc_settings.client_id:
+                payload = jwt.decode(token, jwks, algorithms=['RS256'], audience=oidc_settings.client_id)
+            else:
+                payload = jwt.decode(token, jwks, algorithms=['RS256'])
+
+            token_data = token_cls(**payload)
+        except jwt.ExpiredSignatureError as e:
+            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail='Token has expired') from e
+        except JWTError as e:
+            raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail=f'Token is invalid: {e!s}') from e
 
         return token_data
-    except JWTError as e:
-        raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=f'Token is invalid: {str(e)}')
+
+    return get_verified_token
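
The module-level get_verified_token dependency and its global _JWKS cache are replaced by the get_token_verifier() factory: the caller supplies the token class to instantiate and a mutable mapping to serve as the JWKS cache, and receives the FastAPI dependency back. Invalid tokens now yield 401 instead of 403, with a dedicated detail message for expired signatures. A wiring sketch (the app and cache names are illustrative):

    from typing import Annotated

    from fastapi import Depends, FastAPI

    from python3_commons.auth import TokenData, get_token_verifier

    app = FastAPI()
    jwks_cache: dict = {}  # filled lazily on the first verified request
    get_verified_token = get_token_verifier(TokenData, jwks_cache)

    @app.get('/me')
    async def me(token: Annotated[TokenData | None, Depends(get_verified_token)]) -> dict:
        # token is None when oidc_settings.enabled is falsy
        return {'sub': token.sub if token else None}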