panther 5.0.0b3__py3-none-any.whl → 5.0.0b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57) hide show
  1. panther/__init__.py +1 -1
  2. panther/_load_configs.py +46 -37
  3. panther/_utils.py +49 -34
  4. panther/app.py +96 -97
  5. panther/authentications.py +97 -50
  6. panther/background_tasks.py +98 -124
  7. panther/base_request.py +16 -10
  8. panther/base_websocket.py +8 -8
  9. panther/caching.py +16 -80
  10. panther/cli/create_command.py +17 -16
  11. panther/cli/main.py +1 -1
  12. panther/cli/monitor_command.py +11 -6
  13. panther/cli/run_command.py +5 -71
  14. panther/cli/template.py +7 -7
  15. panther/cli/utils.py +58 -69
  16. panther/configs.py +70 -72
  17. panther/db/connections.py +30 -24
  18. panther/db/cursor.py +3 -1
  19. panther/db/models.py +26 -10
  20. panther/db/queries/base_queries.py +4 -5
  21. panther/db/queries/mongodb_queries.py +21 -21
  22. panther/db/queries/pantherdb_queries.py +1 -1
  23. panther/db/queries/queries.py +26 -8
  24. panther/db/utils.py +1 -1
  25. panther/events.py +25 -14
  26. panther/exceptions.py +2 -7
  27. panther/file_handler.py +1 -1
  28. panther/generics.py +74 -100
  29. panther/logging.py +2 -1
  30. panther/main.py +12 -13
  31. panther/middlewares/cors.py +67 -0
  32. panther/middlewares/monitoring.py +5 -3
  33. panther/openapi/urls.py +2 -2
  34. panther/openapi/utils.py +3 -3
  35. panther/openapi/views.py +20 -37
  36. panther/pagination.py +4 -2
  37. panther/panel/apis.py +2 -7
  38. panther/panel/urls.py +2 -6
  39. panther/panel/utils.py +9 -5
  40. panther/panel/views.py +13 -22
  41. panther/permissions.py +2 -1
  42. panther/request.py +2 -1
  43. panther/response.py +101 -94
  44. panther/routings.py +12 -12
  45. panther/serializer.py +20 -43
  46. panther/test.py +73 -58
  47. panther/throttling.py +68 -3
  48. panther/utils.py +5 -11
  49. panther-5.0.0b5.dist-info/METADATA +188 -0
  50. panther-5.0.0b5.dist-info/RECORD +75 -0
  51. panther/monitoring.py +0 -34
  52. panther-5.0.0b3.dist-info/METADATA +0 -223
  53. panther-5.0.0b3.dist-info/RECORD +0 -75
  54. {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/WHEEL +0 -0
  55. {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/entry_points.txt +0 -0
  56. {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/licenses/LICENSE +0 -0
  57. {panther-5.0.0b3.dist-info → panther-5.0.0b5.dist-info}/top_level.txt +0 -0
panther/app.py CHANGED
@@ -2,36 +2,35 @@ import functools
2
2
  import logging
3
3
  import traceback
4
4
  import typing
5
+ from collections.abc import Callable
5
6
  from datetime import timedelta
6
- from typing import Literal, Callable
7
+ from typing import Literal
7
8
 
8
9
  from orjson import JSONDecodeError
9
- from pydantic import ValidationError, BaseModel
10
+ from pydantic import BaseModel, ValidationError
10
11
 
11
12
  from panther._utils import is_function_async
13
+ from panther.base_request import BaseRequest
12
14
  from panther.caching import (
13
15
  get_response_from_cache,
14
16
  set_response_in_cache,
15
- get_throttling_from_cache,
16
- increment_throttling_in_cache
17
17
  )
18
18
  from panther.configs import config
19
19
  from panther.exceptions import (
20
20
  APIError,
21
21
  AuthorizationAPIError,
22
+ BadRequestAPIError,
22
23
  JSONDecodeAPIError,
23
24
  MethodNotAllowedAPIError,
24
- ThrottlingAPIError,
25
- BadRequestAPIError
25
+ PantherError,
26
26
  )
27
- from panther.exceptions import PantherError
28
27
  from panther.middlewares import HTTPMiddleware
29
28
  from panther.openapi import OutputSchema
30
29
  from panther.permissions import BasePermission
31
30
  from panther.request import Request
32
31
  from panther.response import Response
33
32
  from panther.serializer import ModelSerializer
34
- from panther.throttling import Throttling
33
+ from panther.throttling import Throttle
35
34
 
36
35
  __all__ = ('API', 'GenericAPI')
37
36
 
@@ -40,19 +39,20 @@ logger = logging.getLogger('panther')
40
39
 
41
40
  class API:
42
41
  """
42
+ methods: Specify the allowed methods.
43
43
  input_model: The `request.data` will be validated with this attribute, It will raise an
44
44
  `panther.exceptions.BadRequestAPIError` or put the validated data in the `request.validated_data`.
45
+ output_model: The `response.data` will be passed through this class to filter its attributes.
45
46
  output_schema: This attribute only used in creation of OpenAPI scheme which is available in `panther.openapi.urls`
46
47
  You may want to add its `url` to your urls.
47
48
  auth: It will authenticate the user with header of its request or raise an
48
49
  `panther.exceptions.AuthenticationAPIError`.
49
50
  permissions: List of permissions that will be called sequentially after authentication to authorize the user.
50
- throttling: It will limit the users' request on a specific (time-bucket, path)
51
- cache: Response of the request will be cached.
52
- cache_exp_time: Specify the expiry time of the cache. (default is `config.DEFAULT_CACHE_EXP`)
53
- methods: Specify the allowed methods.
51
+ throttling: It will limit the users' request on a specific (time-window, path)
52
+ cache: Specify the duration of the cache (Will be used only in GET requests).
54
53
  middlewares: These middlewares have inner priority than global middlewares.
55
54
  """
55
+
56
56
  func: Callable
57
57
 
58
58
  def __init__(
@@ -60,36 +60,62 @@ class API:
60
60
  *,
61
61
  methods: list[Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']] | None = None,
62
62
  input_model: type[ModelSerializer] | type[BaseModel] | None = None,
63
- output_model: type[BaseModel] | None = None,
63
+ output_model: type[ModelSerializer] | type[BaseModel] | None = None,
64
64
  output_schema: OutputSchema | None = None,
65
65
  auth: bool = False,
66
- permissions: list[BasePermission] | None = None,
67
- throttling: Throttling | None = None,
68
- cache: bool = False,
69
- cache_exp_time: timedelta | int | None = None,
70
- middlewares: list[HTTPMiddleware] | None = None,
66
+ permissions: list[type[BasePermission]] | None = None,
67
+ throttling: Throttle | None = None,
68
+ cache: timedelta | None = None,
69
+ middlewares: list[type[HTTPMiddleware]] | None = None,
70
+ **kwargs,
71
71
  ):
72
- self.methods = {m.upper() for m in methods} if methods else None
72
+ self.methods = {m.upper() for m in methods} if methods else {'GET', 'POST', 'PUT', 'PATCH', 'DELETE'}
73
73
  self.input_model = input_model
74
+ self.output_model = output_model
74
75
  self.output_schema = output_schema
75
76
  self.auth = auth
76
77
  self.permissions = permissions or []
77
78
  self.throttling = throttling
78
79
  self.cache = cache
79
- self.cache_exp_time = cache_exp_time
80
- self.middlewares: list[HTTPMiddleware] | None = middlewares
80
+ self.middlewares = middlewares
81
81
  self.request: Request | None = None
82
- if output_model:
82
+ if kwargs.pop('cache_exp_time', None):
83
+ deprecation_message = (
84
+ traceback.format_stack(limit=2)[0]
85
+ + '\nThe `cache_exp_time` argument has been removed in Panther v5 and is no longer available.'
86
+ '\nYou may want to use `cache` instead.'
87
+ )
88
+ raise PantherError(deprecation_message)
89
+ # Validate Cache
90
+ if self.cache and not isinstance(self.cache, timedelta):
83
91
  deprecation_message = (
84
- traceback.format_stack(limit=2)[0] +
85
- '\nThe `output_model` argument has been removed in Panther v5 and is no longer available.'
86
- '\nPlease update your code to use the new approach. More info: '
87
- 'https://pantherpy.github.io/open_api/'
92
+ traceback.format_stack(limit=2)[0] + '\nThe `cache` argument has been changed in Panther v5, '
93
+ 'it should be an instance of `datetime.timedelta()`.'
88
94
  )
89
95
  raise PantherError(deprecation_message)
96
+ assert self.cache is None or isinstance(self.cache, timedelta)
97
+ # Validate Permissions
98
+ for perm in self.permissions:
99
+ if is_function_async(perm.authorization) is False:
100
+ msg = f'{perm.__name__}.authorization() should be `async`'
101
+ logger.error(msg)
102
+ raise PantherError(msg)
103
+ if type(perm.authorization).__name__ != 'method':
104
+ msg = f'{perm.__name__}.authorization() should be `@classmethod`'
105
+ logger.error(msg)
106
+ raise PantherError(msg)
107
+ # Check kwargs
108
+ if kwargs:
109
+ msg = f'Unknown kwargs: {kwargs.keys()}'
110
+ logger.error(msg)
111
+ raise PantherError(msg)
90
112
 
91
113
  def __call__(self, func):
92
114
  self.func = func
115
+ self.is_function_async = is_function_async(self.func)
116
+ self.function_annotations = {
117
+ k: v for k, v in func.__annotations__.items() if v in {BaseRequest, Request, bool, int}
118
+ }
93
119
 
94
120
  @functools.wraps(func)
95
121
  async def wrapper(request: Request) -> Response:
@@ -110,37 +136,40 @@ class API:
110
136
  async def handle_endpoint(self, request: Request) -> Response:
111
137
  self.request = request
112
138
 
113
- # 0. Preflight
114
- if self.request.method == 'OPTIONS':
115
- return self.options()
116
-
117
139
  # 1. Check Method
118
- if self.methods and self.request.method not in self.methods:
140
+ if self.request.method not in self.methods:
119
141
  raise MethodNotAllowedAPIError
120
142
 
121
143
  # 2. Authentication
122
- await self.handle_authentication()
144
+ if self.auth:
145
+ if not config.AUTHENTICATION:
146
+ logger.critical('"AUTHENTICATION" has not been set in configs')
147
+ raise APIError
148
+ self.request.user = await config.AUTHENTICATION.authentication(self.request)
123
149
 
124
150
  # 3. Permissions
125
- await self.handle_permission()
151
+ for perm in self.permissions:
152
+ if await perm.authorization(self.request) is False:
153
+ raise AuthorizationAPIError
126
154
 
127
- # 4. Throttling
128
- await self.handle_throttling()
155
+ # 4. Throttle
156
+ if throttling := self.throttling or config.THROTTLING:
157
+ await throttling.check_and_increment(request=self.request)
129
158
 
130
159
  # 5. Validate Input
131
- if self.request.method in {'POST', 'PUT', 'PATCH'}:
132
- self.handle_input_validation()
160
+ if self.input_model and self.request.method in {'POST', 'PUT', 'PATCH'}:
161
+ self.request.validated_data = self.validate_input(model=self.input_model, request=self.request)
133
162
 
134
163
  # 6. Get Cached Response
135
164
  if self.cache and self.request.method == 'GET':
136
- if cached := await get_response_from_cache(request=self.request, cache_exp_time=self.cache_exp_time):
165
+ if cached := await get_response_from_cache(request=self.request, duration=self.cache):
137
166
  return Response(data=cached.data, headers=cached.headers, status_code=cached.status_code)
138
167
 
139
168
  # 7. Put PathVariables and Request(If User Wants It) In kwargs
140
- kwargs = self.request.clean_parameters(self.func)
169
+ kwargs = self.request.clean_parameters(self.function_annotations)
141
170
 
142
171
  # 8. Call Endpoint
143
- if is_function_async(self.func):
172
+ if self.is_function_async:
144
173
  response = await self.func(**kwargs)
145
174
  else:
146
175
  response = self.func(**kwargs)
@@ -148,53 +177,17 @@ class API:
148
177
  # 9. Clean Response
149
178
  if not isinstance(response, Response):
150
179
  response = Response(data=response)
180
+ if self.output_model and response.data:
181
+ response.data = await response.serialize_output(output_model=self.output_model)
151
182
  if response.pagination:
152
183
  response.data = await response.pagination.template(response.data)
153
184
 
154
185
  # 10. Set New Response To Cache
155
186
  if self.cache and self.request.method == 'GET':
156
- await set_response_in_cache(request=self.request, response=response, cache_exp_time=self.cache_exp_time)
157
-
158
- # 11. Warning CacheExpTime
159
- if self.cache_exp_time and self.cache is False:
160
- logger.warning('"cache_exp_time" won\'t work while "cache" is False')
187
+ await set_response_in_cache(request=self.request, response=response, duration=self.cache)
161
188
 
162
189
  return response
163
190
 
164
- async def handle_authentication(self) -> None:
165
- if self.auth:
166
- if not config.AUTHENTICATION:
167
- logger.critical('"AUTHENTICATION" has not been set in configs')
168
- raise APIError
169
- self.request.user = await config.AUTHENTICATION.authentication(self.request)
170
-
171
- async def handle_throttling(self) -> None:
172
- if throttling := self.throttling or config.THROTTLING:
173
- if await get_throttling_from_cache(self.request, duration=throttling.duration) + 1 > throttling.rate:
174
- raise ThrottlingAPIError
175
-
176
- await increment_throttling_in_cache(self.request, duration=throttling.duration)
177
-
178
- async def handle_permission(self) -> None:
179
- for perm in self.permissions:
180
- if type(perm.authorization).__name__ != 'method':
181
- logger.error(f'{perm.__name__}.authorization should be "classmethod"')
182
- raise AuthorizationAPIError
183
- if await perm.authorization(self.request) is False:
184
- raise AuthorizationAPIError
185
-
186
- def handle_input_validation(self):
187
- if self.input_model:
188
- self.request.validated_data = self.validate_input(model=self.input_model, request=self.request)
189
-
190
- @classmethod
191
- def options(cls):
192
- headers = {
193
- 'Access-Control-Allow-Methods': 'DELETE, GET, PATCH, POST, PUT, OPTIONS, HEAD',
194
- 'Access-Control-Allow-Headers': 'Accept, Authorization, User-Agent, Content-Type',
195
- }
196
- return Response(headers=headers)
197
-
198
191
  @classmethod
199
192
  def validate_input(cls, model, request: Request):
200
193
  if isinstance(request.data, bytes):
@@ -212,21 +205,15 @@ class API:
212
205
 
213
206
 
214
207
  class MetaGenericAPI(type):
215
- def __new__(
216
- cls,
217
- cls_name: str,
218
- bases: tuple[type[typing.Any], ...],
219
- namespace: dict[str, typing.Any],
220
- **kwargs
221
- ):
208
+ def __new__(cls, cls_name: str, bases: tuple[type[typing.Any], ...], namespace: dict[str, typing.Any], **kwargs):
222
209
  if cls_name == 'GenericAPI':
223
210
  return super().__new__(cls, cls_name, bases, namespace)
224
211
  if 'output_model' in namespace:
225
212
  deprecation_message = (
226
- traceback.format_stack(limit=2)[0] +
227
- '\nThe `output_model` argument has been removed in Panther v5 and is no longer available.'
228
- '\nPlease update your code to use the new approach. More info: '
229
- 'https://pantherpy.github.io/open_api/'
213
+ traceback.format_stack(limit=2)[0]
214
+ + '\nThe `output_model` argument has been removed in Panther v5 and is no longer available.'
215
+ '\nPlease update your code to use the new approach. More info: '
216
+ 'https://pantherpy.github.io/open_api/'
230
217
  )
231
218
  raise PantherError(deprecation_message)
232
219
  return super().__new__(cls, cls_name, bases, namespace)
@@ -236,15 +223,29 @@ class GenericAPI(metaclass=MetaGenericAPI):
236
223
  """
237
224
  Check out the documentation of `panther.app.API()`.
238
225
  """
226
+
239
227
  input_model: type[ModelSerializer] | type[BaseModel] | None = None
228
+ output_model: type[ModelSerializer] | type[BaseModel] | None = None
240
229
  output_schema: OutputSchema | None = None
241
230
  auth: bool = False
242
- permissions: list | None = None
243
- throttling: Throttling | None = None
244
- cache: bool = False
245
- cache_exp_time: timedelta | int | None = None
231
+ permissions: list[type[BasePermission]] | None = None
232
+ throttling: Throttle | None = None
233
+ cache: timedelta | None = None
246
234
  middlewares: list[HTTPMiddleware] | None = None
247
235
 
236
+ def __init_subclass__(cls, **kwargs):
237
+ # Creating API instance to validate the attributes.
238
+ API(
239
+ input_model=cls.input_model,
240
+ output_model=cls.output_model,
241
+ output_schema=cls.output_schema,
242
+ auth=cls.auth,
243
+ permissions=cls.permissions,
244
+ throttling=cls.throttling,
245
+ cache=cls.cache,
246
+ middlewares=cls.middlewares,
247
+ )
248
+
248
249
  async def get(self, *args, **kwargs):
249
250
  raise MethodNotAllowedAPIError
250
251
 
@@ -272,18 +273,16 @@ class GenericAPI(metaclass=MetaGenericAPI):
272
273
  func = self.patch
273
274
  case 'DELETE':
274
275
  func = self.delete
275
- case 'OPTIONS':
276
- func = API.options
277
276
  case _:
278
277
  raise MethodNotAllowedAPIError
279
278
 
280
279
  return await API(
281
280
  input_model=self.input_model,
281
+ output_model=self.output_model,
282
282
  output_schema=self.output_schema,
283
283
  auth=self.auth,
284
284
  permissions=self.permissions,
285
285
  throttling=self.throttling,
286
286
  cache=self.cache,
287
- cache_exp_time=self.cache_exp_time,
288
287
  middlewares=self.middlewares,
289
288
  )(func)(request=request)
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  import time
3
3
  from abc import abstractmethod
4
- from datetime import timezone, datetime
4
+ from datetime import datetime, timezone
5
5
  from typing import Literal
6
6
 
7
7
  from panther.base_websocket import Websocket
@@ -36,24 +36,44 @@ class BaseAuthentication:
36
36
 
37
37
 
38
38
  class JWTAuthentication(BaseAuthentication):
39
- model = BaseUser
39
+ """
40
+ Retrieve the Authorization from header
41
+ Example:
42
+ Headers: {'authorization': 'Bearer the_jwt_token'}
43
+ """
44
+
45
+ model = None
40
46
  keyword = 'Bearer'
41
47
  algorithm = 'HS256'
42
48
  HTTP_HEADER_ENCODING = 'iso-8859-1' # RFC5987
43
49
 
44
50
  @classmethod
45
- def get_authorization_header(cls, request: Request | Websocket) -> str:
51
+ def get_authorization_header(cls, request: Request | Websocket) -> list[str]:
52
+ """Retrieve the Authorization header from the request."""
46
53
  if auth := request.headers.authorization:
47
- return auth
54
+ return auth.split()
48
55
  msg = 'Authorization is required'
49
56
  raise cls.exception(msg) from None
50
57
 
51
58
  @classmethod
52
59
  async def authentication(cls, request: Request | Websocket) -> Model:
53
- auth_header = cls.get_authorization_header(request).split()
60
+ """Authenticate the user based on the JWT token in the Authorization header."""
61
+ auth_header = cls.get_authorization_header(request)
62
+ token = cls.get_token(auth_header=auth_header)
63
+
64
+ if redis.is_connected and await cls.is_token_revoked(token=token):
65
+ msg = 'User logged out'
66
+ raise cls.exception(msg) from None
54
67
 
68
+ payload = await cls.decode_jwt(token)
69
+ user = await cls.get_user(payload)
70
+ user._auth_token = token
71
+ return user
72
+
73
+ @classmethod
74
+ def get_token(cls, auth_header):
55
75
  if len(auth_header) != 2:
56
- msg = 'Authorization should have 2 part'
76
+ msg = 'Authorization header must contain 2 parts'
57
77
  raise cls.exception(msg) from None
58
78
 
59
79
  bearer, token = auth_header
@@ -67,18 +87,11 @@ class JWTAuthentication(BaseAuthentication):
67
87
  msg = 'Authorization keyword is not valid'
68
88
  raise cls.exception(msg) from None
69
89
 
70
- if redis.is_connected and await cls._check_in_cache(token=token):
71
- msg = 'User logged out'
72
- raise cls.exception(msg) from None
73
-
74
- payload = cls.decode_jwt(token)
75
- user = await cls.get_user(payload)
76
- user._auth_token = token
77
- return user
90
+ return token
78
91
 
79
92
  @classmethod
80
- def decode_jwt(cls, token: str) -> dict:
81
- """Decode JWT token to user_id (it can return multiple variable ... )"""
93
+ async def decode_jwt(cls, token: str) -> dict:
94
+ """Decode a JWT token and return the payload."""
82
95
  try:
83
96
  return jwt.decode(
84
97
  token=token,
@@ -90,21 +103,21 @@ class JWTAuthentication(BaseAuthentication):
90
103
 
91
104
  @classmethod
92
105
  async def get_user(cls, payload: dict) -> Model:
93
- """Get UserModel from config, else use default UserModel from cls.model"""
106
+ """Fetch the user based on the decoded JWT payload from cls.model or config.UserModel"""
94
107
  if (user_id := payload.get('user_id')) is None:
95
108
  msg = 'Payload does not have `user_id`'
96
109
  raise cls.exception(msg)
97
110
 
98
- user_model = config.USER_MODEL or cls.model
99
- if user := await user_model.find_one(id=user_id):
100
- return user
111
+ user_model = cls.model or config.USER_MODEL
112
+ user = await user_model.find_one(id=user_id)
113
+ if user is None:
114
+ raise cls.exception('User not found')
101
115
 
102
- msg = 'User not found'
103
- raise cls.exception(msg) from None
116
+ return user
104
117
 
105
118
  @classmethod
106
119
  def encode_jwt(cls, user_id: str, token_type: Literal['access', 'refresh'] = 'access') -> str:
107
- """Encode JWT from user_id."""
120
+ """Generate a JWT token for a given user ID."""
108
121
  issued_at = datetime.now(timezone.utc).timestamp()
109
122
  if token_type == 'access':
110
123
  expire = issued_at + config.JWT_CONFIG.life_time
@@ -124,44 +137,87 @@ class JWTAuthentication(BaseAuthentication):
124
137
  )
125
138
 
126
139
  @classmethod
127
- def login(cls, user_id: str) -> dict:
128
- """Return dict of access and refresh token"""
140
+ async def login(cls, user) -> dict:
141
+ """Generate access and refresh tokens for user login."""
129
142
  return {
130
- 'access_token': cls.encode_jwt(user_id=user_id),
131
- 'refresh_token': cls.encode_jwt(user_id=user_id, token_type='refresh')
143
+ 'access_token': cls.encode_jwt(user_id=user.id),
144
+ 'refresh_token': cls.encode_jwt(user_id=user.id, token_type='refresh'),
132
145
  }
133
146
 
134
147
  @classmethod
135
- async def logout(cls, raw_token: str) -> None:
136
- *_, token = raw_token.split()
137
- if redis.is_connected:
138
- payload = cls.decode_jwt(token=token)
139
- remaining_exp_time = payload['exp'] - time.time()
140
- await cls._set_in_cache(token=token, exp=int(remaining_exp_time))
148
+ async def logout(cls, user) -> None:
149
+ """Log out a user by revoking their JWT token."""
150
+ payload = await cls.decode_jwt(token=user._auth_token)
151
+ await cls.revoke_token_in_cache(token=user._auth_token, exp=payload['exp'])
152
+
153
+ @classmethod
154
+ async def refresh(cls, user):
155
+ if hasattr(user, '_auth_refresh_token'):
156
+ # It happens in CookieJWTAuthentication
157
+ token = user._auth_refresh_token
141
158
  else:
142
- logger.error('`redis` middleware is required for `logout()`')
159
+ token = user._auth_token
160
+
161
+ payload = await cls.decode_jwt(token=token)
162
+
163
+ if payload['token_type'] != 'refresh':
164
+ raise cls.exception('Invalid token type; expected `refresh` token.')
165
+ # Revoke after use
166
+ await cls.revoke_token_in_cache(token=token, exp=payload['exp'])
167
+
168
+ return await cls.login(user=user)
143
169
 
144
170
  @classmethod
145
- async def _set_in_cache(cls, token: str, exp: int) -> None:
146
- key = generate_hash_value_from_string(token)
147
- await redis.set(key, b'', ex=exp)
171
+ async def revoke_token_in_cache(cls, token: str, exp: int) -> None:
172
+ """Mark the token as revoked in the cache."""
173
+ if redis.is_connected:
174
+ key = generate_hash_value_from_string(token)
175
+ remaining_exp_time = int(exp - time.time())
176
+ await redis.set(key, b'', ex=remaining_exp_time)
177
+ else:
178
+ logger.error('Redis is not connected; token revocation is not effective.')
148
179
 
149
180
  @classmethod
150
- async def _check_in_cache(cls, token: str) -> bool:
181
+ async def is_token_revoked(cls, token: str) -> bool:
182
+ """Check if the token is revoked by looking it up in the cache."""
151
183
  key = generate_hash_value_from_string(token)
152
184
  return bool(await redis.exists(key))
153
185
 
154
186
 
155
187
  class QueryParamJWTAuthentication(JWTAuthentication):
188
+ """
189
+ Retrieve the Authorization from query params
190
+ Example:
191
+ https://example.com?authorization=the_jwt_without_bearer
192
+ """
193
+
156
194
  @classmethod
157
- def get_authorization_header(cls, request: Request | Websocket) -> str:
195
+ def get_authorization_header(cls, request: Request | Websocket) -> list[str]:
158
196
  if auth := request.query_params.get('authorization'):
159
197
  return auth
160
198
  msg = '`authorization` query param not found.'
161
199
  raise cls.exception(msg) from None
162
200
 
201
+ @classmethod
202
+ def get_token(cls, auth_header) -> str:
203
+ return auth_header
204
+
163
205
 
164
206
  class CookieJWTAuthentication(JWTAuthentication):
207
+ """
208
+ Retrieve the Authorization from cookies
209
+ Example:
210
+ Cookies: access_token=the_jwt_without_bearer
211
+ """
212
+
213
+ @classmethod
214
+ async def authentication(cls, request: Request | Websocket) -> Model:
215
+ user = await super().authentication(request=request)
216
+ if refresh_token := request.headers.get_cookies().get('refresh_token'):
217
+ # It's used in `cls.refresh()`
218
+ user._auth_refresh_token = refresh_token
219
+ return user
220
+
165
221
  @classmethod
166
222
  def get_authorization_header(cls, request: Request | Websocket) -> str:
167
223
  if token := request.headers.get_cookies().get('access_token'):
@@ -170,14 +226,5 @@ class CookieJWTAuthentication(JWTAuthentication):
170
226
  raise cls.exception(msg) from None
171
227
 
172
228
  @classmethod
173
- async def authentication(cls, request: Request | Websocket) -> Model:
174
- token = cls.get_authorization_header(request)
175
-
176
- if redis.is_connected and await cls._check_in_cache(token=token):
177
- msg = 'User logged out'
178
- raise cls.exception(msg) from None
179
-
180
- payload = cls.decode_jwt(token)
181
- user = await cls.get_user(payload)
182
- user._auth_token = token
183
- return user
229
+ def get_token(cls, auth_header) -> str:
230
+ return auth_header