panther 5.0.0b3__py3-none-any.whl → 5.0.0b5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- panther/__init__.py +1 -1
- panther/_load_configs.py +46 -37
- panther/_utils.py +49 -34
- panther/app.py +96 -97
- panther/authentications.py +97 -50
- panther/background_tasks.py +98 -124
- panther/base_request.py +16 -10
- panther/base_websocket.py +8 -8
- panther/caching.py +16 -80
- panther/cli/create_command.py +17 -16
- panther/cli/main.py +1 -1
- panther/cli/monitor_command.py +11 -6
- panther/cli/run_command.py +5 -71
- panther/cli/template.py +7 -7
- panther/cli/utils.py +58 -69
- panther/configs.py +70 -72
- panther/db/connections.py +30 -24
- panther/db/cursor.py +3 -1
- panther/db/models.py +26 -10
- panther/db/queries/base_queries.py +4 -5
- panther/db/queries/mongodb_queries.py +21 -21
- panther/db/queries/pantherdb_queries.py +1 -1
- panther/db/queries/queries.py +26 -8
- panther/db/utils.py +1 -1
- panther/events.py +25 -14
- panther/exceptions.py +2 -7
- panther/file_handler.py +1 -1
- panther/generics.py +74 -100
- panther/logging.py +2 -1
- panther/main.py +12 -13
- panther/middlewares/cors.py +67 -0
- panther/middlewares/monitoring.py +5 -3
- panther/openapi/urls.py +2 -2
- panther/openapi/utils.py +3 -3
- panther/openapi/views.py +20 -37
- panther/pagination.py +4 -2
- panther/panel/apis.py +2 -7
- panther/panel/urls.py +2 -6
- panther/panel/utils.py +9 -5
- panther/panel/views.py +13 -22
- panther/permissions.py +2 -1
- panther/request.py +2 -1
- panther/response.py +101 -94
- panther/routings.py +12 -12
- panther/serializer.py +20 -43
- panther/test.py +73 -58
- panther/throttling.py +68 -3
- panther/utils.py +5 -11
- panther-5.0.0b5.dist-info/METADATA +188 -0
- panther-5.0.0b5.dist-info/RECORD +75 -0
- panther/monitoring.py +0 -34
- panther-5.0.0b3.dist-info/METADATA +0 -223
- panther-5.0.0b3.dist-info/RECORD +0 -75
- {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/WHEEL +0 -0
- {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/entry_points.txt +0 -0
- {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/licenses/LICENSE +0 -0
- {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/top_level.txt +0 -0
panther/app.py
CHANGED
@@ -2,36 +2,35 @@ import functools
|
|
2
2
|
import logging
|
3
3
|
import traceback
|
4
4
|
import typing
|
5
|
+
from collections.abc import Callable
|
5
6
|
from datetime import timedelta
|
6
|
-
from typing import Literal
|
7
|
+
from typing import Literal
|
7
8
|
|
8
9
|
from orjson import JSONDecodeError
|
9
|
-
from pydantic import
|
10
|
+
from pydantic import BaseModel, ValidationError
|
10
11
|
|
11
12
|
from panther._utils import is_function_async
|
13
|
+
from panther.base_request import BaseRequest
|
12
14
|
from panther.caching import (
|
13
15
|
get_response_from_cache,
|
14
16
|
set_response_in_cache,
|
15
|
-
get_throttling_from_cache,
|
16
|
-
increment_throttling_in_cache
|
17
17
|
)
|
18
18
|
from panther.configs import config
|
19
19
|
from panther.exceptions import (
|
20
20
|
APIError,
|
21
21
|
AuthorizationAPIError,
|
22
|
+
BadRequestAPIError,
|
22
23
|
JSONDecodeAPIError,
|
23
24
|
MethodNotAllowedAPIError,
|
24
|
-
|
25
|
-
BadRequestAPIError
|
25
|
+
PantherError,
|
26
26
|
)
|
27
|
-
from panther.exceptions import PantherError
|
28
27
|
from panther.middlewares import HTTPMiddleware
|
29
28
|
from panther.openapi import OutputSchema
|
30
29
|
from panther.permissions import BasePermission
|
31
30
|
from panther.request import Request
|
32
31
|
from panther.response import Response
|
33
32
|
from panther.serializer import ModelSerializer
|
34
|
-
from panther.throttling import
|
33
|
+
from panther.throttling import Throttle
|
35
34
|
|
36
35
|
__all__ = ('API', 'GenericAPI')
|
37
36
|
|
@@ -40,19 +39,20 @@ logger = logging.getLogger('panther')
|
|
40
39
|
|
41
40
|
class API:
|
42
41
|
"""
|
42
|
+
methods: Specify the allowed methods.
|
43
43
|
input_model: The `request.data` will be validated with this attribute, It will raise an
|
44
44
|
`panther.exceptions.BadRequestAPIError` or put the validated data in the `request.validated_data`.
|
45
|
+
output_model: The `response.data` will be passed through this class to filter its attributes.
|
45
46
|
output_schema: This attribute only used in creation of OpenAPI scheme which is available in `panther.openapi.urls`
|
46
47
|
You may want to add its `url` to your urls.
|
47
48
|
auth: It will authenticate the user with header of its request or raise an
|
48
49
|
`panther.exceptions.AuthenticationAPIError`.
|
49
50
|
permissions: List of permissions that will be called sequentially after authentication to authorize the user.
|
50
|
-
throttling: It will limit the users' request on a specific (time-
|
51
|
-
cache:
|
52
|
-
cache_exp_time: Specify the expiry time of the cache. (default is `config.DEFAULT_CACHE_EXP`)
|
53
|
-
methods: Specify the allowed methods.
|
51
|
+
throttling: It will limit the users' request on a specific (time-window, path)
|
52
|
+
cache: Specify the duration of the cache (Will be used only in GET requests).
|
54
53
|
middlewares: These middlewares have inner priority than global middlewares.
|
55
54
|
"""
|
55
|
+
|
56
56
|
func: Callable
|
57
57
|
|
58
58
|
def __init__(
|
@@ -60,36 +60,62 @@ class API:
|
|
60
60
|
*,
|
61
61
|
methods: list[Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']] | None = None,
|
62
62
|
input_model: type[ModelSerializer] | type[BaseModel] | None = None,
|
63
|
-
output_model: type[BaseModel] | None = None,
|
63
|
+
output_model: type[ModelSerializer] | type[BaseModel] | None = None,
|
64
64
|
output_schema: OutputSchema | None = None,
|
65
65
|
auth: bool = False,
|
66
|
-
permissions: list[BasePermission] | None = None,
|
67
|
-
throttling:
|
68
|
-
cache:
|
69
|
-
|
70
|
-
|
66
|
+
permissions: list[type[BasePermission]] | None = None,
|
67
|
+
throttling: Throttle | None = None,
|
68
|
+
cache: timedelta | None = None,
|
69
|
+
middlewares: list[type[HTTPMiddleware]] | None = None,
|
70
|
+
**kwargs,
|
71
71
|
):
|
72
|
-
self.methods = {m.upper() for m in methods} if methods else
|
72
|
+
self.methods = {m.upper() for m in methods} if methods else {'GET', 'POST', 'PUT', 'PATCH', 'DELETE'}
|
73
73
|
self.input_model = input_model
|
74
|
+
self.output_model = output_model
|
74
75
|
self.output_schema = output_schema
|
75
76
|
self.auth = auth
|
76
77
|
self.permissions = permissions or []
|
77
78
|
self.throttling = throttling
|
78
79
|
self.cache = cache
|
79
|
-
self.
|
80
|
-
self.middlewares: list[HTTPMiddleware] | None = middlewares
|
80
|
+
self.middlewares = middlewares
|
81
81
|
self.request: Request | None = None
|
82
|
-
if
|
82
|
+
if kwargs.pop('cache_exp_time', None):
|
83
|
+
deprecation_message = (
|
84
|
+
traceback.format_stack(limit=2)[0]
|
85
|
+
+ '\nThe `cache_exp_time` argument has been removed in Panther v5 and is no longer available.'
|
86
|
+
'\nYou may want to use `cache` instead.'
|
87
|
+
)
|
88
|
+
raise PantherError(deprecation_message)
|
89
|
+
# Validate Cache
|
90
|
+
if self.cache and not isinstance(self.cache, timedelta):
|
83
91
|
deprecation_message = (
|
84
|
-
|
85
|
-
|
86
|
-
'\nPlease update your code to use the new approach. More info: '
|
87
|
-
'https://pantherpy.github.io/open_api/'
|
92
|
+
traceback.format_stack(limit=2)[0] + '\nThe `cache` argument has been changed in Panther v5, '
|
93
|
+
'it should be an instance of `datetime.timedelta()`.'
|
88
94
|
)
|
89
95
|
raise PantherError(deprecation_message)
|
96
|
+
assert self.cache is None or isinstance(self.cache, timedelta)
|
97
|
+
# Validate Permissions
|
98
|
+
for perm in self.permissions:
|
99
|
+
if is_function_async(perm.authorization) is False:
|
100
|
+
msg = f'{perm.__name__}.authorization() should be `async`'
|
101
|
+
logger.error(msg)
|
102
|
+
raise PantherError(msg)
|
103
|
+
if type(perm.authorization).__name__ != 'method':
|
104
|
+
msg = f'{perm.__name__}.authorization() should be `@classmethod`'
|
105
|
+
logger.error(msg)
|
106
|
+
raise PantherError(msg)
|
107
|
+
# Check kwargs
|
108
|
+
if kwargs:
|
109
|
+
msg = f'Unknown kwargs: {kwargs.keys()}'
|
110
|
+
logger.error(msg)
|
111
|
+
raise PantherError(msg)
|
90
112
|
|
91
113
|
def __call__(self, func):
|
92
114
|
self.func = func
|
115
|
+
self.is_function_async = is_function_async(self.func)
|
116
|
+
self.function_annotations = {
|
117
|
+
k: v for k, v in func.__annotations__.items() if v in {BaseRequest, Request, bool, int}
|
118
|
+
}
|
93
119
|
|
94
120
|
@functools.wraps(func)
|
95
121
|
async def wrapper(request: Request) -> Response:
|
@@ -110,37 +136,40 @@ class API:
|
|
110
136
|
async def handle_endpoint(self, request: Request) -> Response:
|
111
137
|
self.request = request
|
112
138
|
|
113
|
-
# 0. Preflight
|
114
|
-
if self.request.method == 'OPTIONS':
|
115
|
-
return self.options()
|
116
|
-
|
117
139
|
# 1. Check Method
|
118
|
-
if self.
|
140
|
+
if self.request.method not in self.methods:
|
119
141
|
raise MethodNotAllowedAPIError
|
120
142
|
|
121
143
|
# 2. Authentication
|
122
|
-
|
144
|
+
if self.auth:
|
145
|
+
if not config.AUTHENTICATION:
|
146
|
+
logger.critical('"AUTHENTICATION" has not been set in configs')
|
147
|
+
raise APIError
|
148
|
+
self.request.user = await config.AUTHENTICATION.authentication(self.request)
|
123
149
|
|
124
150
|
# 3. Permissions
|
125
|
-
|
151
|
+
for perm in self.permissions:
|
152
|
+
if await perm.authorization(self.request) is False:
|
153
|
+
raise AuthorizationAPIError
|
126
154
|
|
127
|
-
# 4.
|
128
|
-
|
155
|
+
# 4. Throttle
|
156
|
+
if throttling := self.throttling or config.THROTTLING:
|
157
|
+
await throttling.check_and_increment(request=self.request)
|
129
158
|
|
130
159
|
# 5. Validate Input
|
131
|
-
if self.request.method in {'POST', 'PUT', 'PATCH'}:
|
132
|
-
self.
|
160
|
+
if self.input_model and self.request.method in {'POST', 'PUT', 'PATCH'}:
|
161
|
+
self.request.validated_data = self.validate_input(model=self.input_model, request=self.request)
|
133
162
|
|
134
163
|
# 6. Get Cached Response
|
135
164
|
if self.cache and self.request.method == 'GET':
|
136
|
-
if cached := await get_response_from_cache(request=self.request,
|
165
|
+
if cached := await get_response_from_cache(request=self.request, duration=self.cache):
|
137
166
|
return Response(data=cached.data, headers=cached.headers, status_code=cached.status_code)
|
138
167
|
|
139
168
|
# 7. Put PathVariables and Request(If User Wants It) In kwargs
|
140
|
-
kwargs = self.request.clean_parameters(self.
|
169
|
+
kwargs = self.request.clean_parameters(self.function_annotations)
|
141
170
|
|
142
171
|
# 8. Call Endpoint
|
143
|
-
if
|
172
|
+
if self.is_function_async:
|
144
173
|
response = await self.func(**kwargs)
|
145
174
|
else:
|
146
175
|
response = self.func(**kwargs)
|
@@ -148,53 +177,17 @@ class API:
|
|
148
177
|
# 9. Clean Response
|
149
178
|
if not isinstance(response, Response):
|
150
179
|
response = Response(data=response)
|
180
|
+
if self.output_model and response.data:
|
181
|
+
response.data = await response.serialize_output(output_model=self.output_model)
|
151
182
|
if response.pagination:
|
152
183
|
response.data = await response.pagination.template(response.data)
|
153
184
|
|
154
185
|
# 10. Set New Response To Cache
|
155
186
|
if self.cache and self.request.method == 'GET':
|
156
|
-
await set_response_in_cache(request=self.request, response=response,
|
157
|
-
|
158
|
-
# 11. Warning CacheExpTime
|
159
|
-
if self.cache_exp_time and self.cache is False:
|
160
|
-
logger.warning('"cache_exp_time" won\'t work while "cache" is False')
|
187
|
+
await set_response_in_cache(request=self.request, response=response, duration=self.cache)
|
161
188
|
|
162
189
|
return response
|
163
190
|
|
164
|
-
async def handle_authentication(self) -> None:
|
165
|
-
if self.auth:
|
166
|
-
if not config.AUTHENTICATION:
|
167
|
-
logger.critical('"AUTHENTICATION" has not been set in configs')
|
168
|
-
raise APIError
|
169
|
-
self.request.user = await config.AUTHENTICATION.authentication(self.request)
|
170
|
-
|
171
|
-
async def handle_throttling(self) -> None:
|
172
|
-
if throttling := self.throttling or config.THROTTLING:
|
173
|
-
if await get_throttling_from_cache(self.request, duration=throttling.duration) + 1 > throttling.rate:
|
174
|
-
raise ThrottlingAPIError
|
175
|
-
|
176
|
-
await increment_throttling_in_cache(self.request, duration=throttling.duration)
|
177
|
-
|
178
|
-
async def handle_permission(self) -> None:
|
179
|
-
for perm in self.permissions:
|
180
|
-
if type(perm.authorization).__name__ != 'method':
|
181
|
-
logger.error(f'{perm.__name__}.authorization should be "classmethod"')
|
182
|
-
raise AuthorizationAPIError
|
183
|
-
if await perm.authorization(self.request) is False:
|
184
|
-
raise AuthorizationAPIError
|
185
|
-
|
186
|
-
def handle_input_validation(self):
|
187
|
-
if self.input_model:
|
188
|
-
self.request.validated_data = self.validate_input(model=self.input_model, request=self.request)
|
189
|
-
|
190
|
-
@classmethod
|
191
|
-
def options(cls):
|
192
|
-
headers = {
|
193
|
-
'Access-Control-Allow-Methods': 'DELETE, GET, PATCH, POST, PUT, OPTIONS, HEAD',
|
194
|
-
'Access-Control-Allow-Headers': 'Accept, Authorization, User-Agent, Content-Type',
|
195
|
-
}
|
196
|
-
return Response(headers=headers)
|
197
|
-
|
198
191
|
@classmethod
|
199
192
|
def validate_input(cls, model, request: Request):
|
200
193
|
if isinstance(request.data, bytes):
|
@@ -212,21 +205,15 @@ class API:
|
|
212
205
|
|
213
206
|
|
214
207
|
class MetaGenericAPI(type):
|
215
|
-
def __new__(
|
216
|
-
cls,
|
217
|
-
cls_name: str,
|
218
|
-
bases: tuple[type[typing.Any], ...],
|
219
|
-
namespace: dict[str, typing.Any],
|
220
|
-
**kwargs
|
221
|
-
):
|
208
|
+
def __new__(cls, cls_name: str, bases: tuple[type[typing.Any], ...], namespace: dict[str, typing.Any], **kwargs):
|
222
209
|
if cls_name == 'GenericAPI':
|
223
210
|
return super().__new__(cls, cls_name, bases, namespace)
|
224
211
|
if 'output_model' in namespace:
|
225
212
|
deprecation_message = (
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
213
|
+
traceback.format_stack(limit=2)[0]
|
214
|
+
+ '\nThe `output_model` argument has been removed in Panther v5 and is no longer available.'
|
215
|
+
'\nPlease update your code to use the new approach. More info: '
|
216
|
+
'https://pantherpy.github.io/open_api/'
|
230
217
|
)
|
231
218
|
raise PantherError(deprecation_message)
|
232
219
|
return super().__new__(cls, cls_name, bases, namespace)
|
@@ -236,15 +223,29 @@ class GenericAPI(metaclass=MetaGenericAPI):
|
|
236
223
|
"""
|
237
224
|
Check out the documentation of `panther.app.API()`.
|
238
225
|
"""
|
226
|
+
|
239
227
|
input_model: type[ModelSerializer] | type[BaseModel] | None = None
|
228
|
+
output_model: type[ModelSerializer] | type[BaseModel] | None = None
|
240
229
|
output_schema: OutputSchema | None = None
|
241
230
|
auth: bool = False
|
242
|
-
permissions: list | None = None
|
243
|
-
throttling:
|
244
|
-
cache:
|
245
|
-
cache_exp_time: timedelta | int | None = None
|
231
|
+
permissions: list[type[BasePermission]] | None = None
|
232
|
+
throttling: Throttle | None = None
|
233
|
+
cache: timedelta | None = None
|
246
234
|
middlewares: list[HTTPMiddleware] | None = None
|
247
235
|
|
236
|
+
def __init_subclass__(cls, **kwargs):
|
237
|
+
# Creating API instance to validate the attributes.
|
238
|
+
API(
|
239
|
+
input_model=cls.input_model,
|
240
|
+
output_model=cls.output_model,
|
241
|
+
output_schema=cls.output_schema,
|
242
|
+
auth=cls.auth,
|
243
|
+
permissions=cls.permissions,
|
244
|
+
throttling=cls.throttling,
|
245
|
+
cache=cls.cache,
|
246
|
+
middlewares=cls.middlewares,
|
247
|
+
)
|
248
|
+
|
248
249
|
async def get(self, *args, **kwargs):
|
249
250
|
raise MethodNotAllowedAPIError
|
250
251
|
|
@@ -272,18 +273,16 @@ class GenericAPI(metaclass=MetaGenericAPI):
|
|
272
273
|
func = self.patch
|
273
274
|
case 'DELETE':
|
274
275
|
func = self.delete
|
275
|
-
case 'OPTIONS':
|
276
|
-
func = API.options
|
277
276
|
case _:
|
278
277
|
raise MethodNotAllowedAPIError
|
279
278
|
|
280
279
|
return await API(
|
281
280
|
input_model=self.input_model,
|
281
|
+
output_model=self.output_model,
|
282
282
|
output_schema=self.output_schema,
|
283
283
|
auth=self.auth,
|
284
284
|
permissions=self.permissions,
|
285
285
|
throttling=self.throttling,
|
286
286
|
cache=self.cache,
|
287
|
-
cache_exp_time=self.cache_exp_time,
|
288
287
|
middlewares=self.middlewares,
|
289
288
|
)(func)(request=request)
|
panther/authentications.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import logging
|
2
2
|
import time
|
3
3
|
from abc import abstractmethod
|
4
|
-
from datetime import
|
4
|
+
from datetime import datetime, timezone
|
5
5
|
from typing import Literal
|
6
6
|
|
7
7
|
from panther.base_websocket import Websocket
|
@@ -36,24 +36,44 @@ class BaseAuthentication:
|
|
36
36
|
|
37
37
|
|
38
38
|
class JWTAuthentication(BaseAuthentication):
|
39
|
-
|
39
|
+
"""
|
40
|
+
Retrieve the Authorization from header
|
41
|
+
Example:
|
42
|
+
Headers: {'authorization': 'Bearer the_jwt_token'}
|
43
|
+
"""
|
44
|
+
|
45
|
+
model = None
|
40
46
|
keyword = 'Bearer'
|
41
47
|
algorithm = 'HS256'
|
42
48
|
HTTP_HEADER_ENCODING = 'iso-8859-1' # RFC5987
|
43
49
|
|
44
50
|
@classmethod
|
45
|
-
def get_authorization_header(cls, request: Request | Websocket) -> str:
|
51
|
+
def get_authorization_header(cls, request: Request | Websocket) -> list[str]:
|
52
|
+
"""Retrieve the Authorization header from the request."""
|
46
53
|
if auth := request.headers.authorization:
|
47
|
-
return auth
|
54
|
+
return auth.split()
|
48
55
|
msg = 'Authorization is required'
|
49
56
|
raise cls.exception(msg) from None
|
50
57
|
|
51
58
|
@classmethod
|
52
59
|
async def authentication(cls, request: Request | Websocket) -> Model:
|
53
|
-
|
60
|
+
"""Authenticate the user based on the JWT token in the Authorization header."""
|
61
|
+
auth_header = cls.get_authorization_header(request)
|
62
|
+
token = cls.get_token(auth_header=auth_header)
|
63
|
+
|
64
|
+
if redis.is_connected and await cls.is_token_revoked(token=token):
|
65
|
+
msg = 'User logged out'
|
66
|
+
raise cls.exception(msg) from None
|
54
67
|
|
68
|
+
payload = await cls.decode_jwt(token)
|
69
|
+
user = await cls.get_user(payload)
|
70
|
+
user._auth_token = token
|
71
|
+
return user
|
72
|
+
|
73
|
+
@classmethod
|
74
|
+
def get_token(cls, auth_header):
|
55
75
|
if len(auth_header) != 2:
|
56
|
-
msg = 'Authorization
|
76
|
+
msg = 'Authorization header must contain 2 parts'
|
57
77
|
raise cls.exception(msg) from None
|
58
78
|
|
59
79
|
bearer, token = auth_header
|
@@ -67,18 +87,11 @@ class JWTAuthentication(BaseAuthentication):
|
|
67
87
|
msg = 'Authorization keyword is not valid'
|
68
88
|
raise cls.exception(msg) from None
|
69
89
|
|
70
|
-
|
71
|
-
msg = 'User logged out'
|
72
|
-
raise cls.exception(msg) from None
|
73
|
-
|
74
|
-
payload = cls.decode_jwt(token)
|
75
|
-
user = await cls.get_user(payload)
|
76
|
-
user._auth_token = token
|
77
|
-
return user
|
90
|
+
return token
|
78
91
|
|
79
92
|
@classmethod
|
80
|
-
def decode_jwt(cls, token: str) -> dict:
|
81
|
-
"""Decode JWT token
|
93
|
+
async def decode_jwt(cls, token: str) -> dict:
|
94
|
+
"""Decode a JWT token and return the payload."""
|
82
95
|
try:
|
83
96
|
return jwt.decode(
|
84
97
|
token=token,
|
@@ -90,21 +103,21 @@ class JWTAuthentication(BaseAuthentication):
|
|
90
103
|
|
91
104
|
@classmethod
|
92
105
|
async def get_user(cls, payload: dict) -> Model:
|
93
|
-
"""
|
106
|
+
"""Fetch the user based on the decoded JWT payload from cls.model or config.UserModel"""
|
94
107
|
if (user_id := payload.get('user_id')) is None:
|
95
108
|
msg = 'Payload does not have `user_id`'
|
96
109
|
raise cls.exception(msg)
|
97
110
|
|
98
|
-
user_model =
|
99
|
-
|
100
|
-
|
111
|
+
user_model = cls.model or config.USER_MODEL
|
112
|
+
user = await user_model.find_one(id=user_id)
|
113
|
+
if user is None:
|
114
|
+
raise cls.exception('User not found')
|
101
115
|
|
102
|
-
|
103
|
-
raise cls.exception(msg) from None
|
116
|
+
return user
|
104
117
|
|
105
118
|
@classmethod
|
106
119
|
def encode_jwt(cls, user_id: str, token_type: Literal['access', 'refresh'] = 'access') -> str:
|
107
|
-
"""
|
120
|
+
"""Generate a JWT token for a given user ID."""
|
108
121
|
issued_at = datetime.now(timezone.utc).timestamp()
|
109
122
|
if token_type == 'access':
|
110
123
|
expire = issued_at + config.JWT_CONFIG.life_time
|
@@ -124,44 +137,87 @@ class JWTAuthentication(BaseAuthentication):
|
|
124
137
|
)
|
125
138
|
|
126
139
|
@classmethod
|
127
|
-
def login(cls,
|
128
|
-
"""
|
140
|
+
async def login(cls, user) -> dict:
|
141
|
+
"""Generate access and refresh tokens for user login."""
|
129
142
|
return {
|
130
|
-
'access_token': cls.encode_jwt(user_id=
|
131
|
-
'refresh_token': cls.encode_jwt(user_id=
|
143
|
+
'access_token': cls.encode_jwt(user_id=user.id),
|
144
|
+
'refresh_token': cls.encode_jwt(user_id=user.id, token_type='refresh'),
|
132
145
|
}
|
133
146
|
|
134
147
|
@classmethod
|
135
|
-
async def logout(cls,
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
148
|
+
async def logout(cls, user) -> None:
|
149
|
+
"""Log out a user by revoking their JWT token."""
|
150
|
+
payload = await cls.decode_jwt(token=user._auth_token)
|
151
|
+
await cls.revoke_token_in_cache(token=user._auth_token, exp=payload['exp'])
|
152
|
+
|
153
|
+
@classmethod
|
154
|
+
async def refresh(cls, user):
|
155
|
+
if hasattr(user, '_auth_refresh_token'):
|
156
|
+
# It happens in CookieJWTAuthentication
|
157
|
+
token = user._auth_refresh_token
|
141
158
|
else:
|
142
|
-
|
159
|
+
token = user._auth_token
|
160
|
+
|
161
|
+
payload = await cls.decode_jwt(token=token)
|
162
|
+
|
163
|
+
if payload['token_type'] != 'refresh':
|
164
|
+
raise cls.exception('Invalid token type; expected `refresh` token.')
|
165
|
+
# Revoke after use
|
166
|
+
await cls.revoke_token_in_cache(token=token, exp=payload['exp'])
|
167
|
+
|
168
|
+
return await cls.login(user=user)
|
143
169
|
|
144
170
|
@classmethod
|
145
|
-
async def
|
146
|
-
|
147
|
-
|
171
|
+
async def revoke_token_in_cache(cls, token: str, exp: int) -> None:
|
172
|
+
"""Mark the token as revoked in the cache."""
|
173
|
+
if redis.is_connected:
|
174
|
+
key = generate_hash_value_from_string(token)
|
175
|
+
remaining_exp_time = int(exp - time.time())
|
176
|
+
await redis.set(key, b'', ex=remaining_exp_time)
|
177
|
+
else:
|
178
|
+
logger.error('Redis is not connected; token revocation is not effective.')
|
148
179
|
|
149
180
|
@classmethod
|
150
|
-
async def
|
181
|
+
async def is_token_revoked(cls, token: str) -> bool:
|
182
|
+
"""Check if the token is revoked by looking it up in the cache."""
|
151
183
|
key = generate_hash_value_from_string(token)
|
152
184
|
return bool(await redis.exists(key))
|
153
185
|
|
154
186
|
|
155
187
|
class QueryParamJWTAuthentication(JWTAuthentication):
|
188
|
+
"""
|
189
|
+
Retrieve the Authorization from query params
|
190
|
+
Example:
|
191
|
+
https://example.com?authorization=the_jwt_without_bearer
|
192
|
+
"""
|
193
|
+
|
156
194
|
@classmethod
|
157
|
-
def get_authorization_header(cls, request: Request | Websocket) -> str:
|
195
|
+
def get_authorization_header(cls, request: Request | Websocket) -> list[str]:
|
158
196
|
if auth := request.query_params.get('authorization'):
|
159
197
|
return auth
|
160
198
|
msg = '`authorization` query param not found.'
|
161
199
|
raise cls.exception(msg) from None
|
162
200
|
|
201
|
+
@classmethod
|
202
|
+
def get_token(cls, auth_header) -> str:
|
203
|
+
return auth_header
|
204
|
+
|
163
205
|
|
164
206
|
class CookieJWTAuthentication(JWTAuthentication):
|
207
|
+
"""
|
208
|
+
Retrieve the Authorization from cookies
|
209
|
+
Example:
|
210
|
+
Cookies: access_token=the_jwt_without_bearer
|
211
|
+
"""
|
212
|
+
|
213
|
+
@classmethod
|
214
|
+
async def authentication(cls, request: Request | Websocket) -> Model:
|
215
|
+
user = await super().authentication(request=request)
|
216
|
+
if refresh_token := request.headers.get_cookies().get('refresh_token'):
|
217
|
+
# It's used in `cls.refresh()`
|
218
|
+
user._auth_refresh_token = refresh_token
|
219
|
+
return user
|
220
|
+
|
165
221
|
@classmethod
|
166
222
|
def get_authorization_header(cls, request: Request | Websocket) -> str:
|
167
223
|
if token := request.headers.get_cookies().get('access_token'):
|
@@ -170,14 +226,5 @@ class CookieJWTAuthentication(JWTAuthentication):
|
|
170
226
|
raise cls.exception(msg) from None
|
171
227
|
|
172
228
|
@classmethod
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
if redis.is_connected and await cls._check_in_cache(token=token):
|
177
|
-
msg = 'User logged out'
|
178
|
-
raise cls.exception(msg) from None
|
179
|
-
|
180
|
-
payload = cls.decode_jwt(token)
|
181
|
-
user = await cls.get_user(payload)
|
182
|
-
user._auth_token = token
|
183
|
-
return user
|
229
|
+
def get_token(cls, auth_header) -> str:
|
230
|
+
return auth_header
|