aiohttp-msal 0.7.0__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0}/PKG-INFO +23 -24
  2. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0}/README.md +9 -0
  3. aiohttp_msal-1.0.0/pyproject.toml +128 -0
  4. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/__init__.py +34 -20
  5. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/msal_async.py +51 -35
  6. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/redis_tools.py +33 -14
  7. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/routes.py +11 -12
  8. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/settings.py +16 -12
  9. aiohttp_msal-1.0.0/src/aiohttp_msal/settings_base.py +93 -0
  10. {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/user_info.py +8 -4
  11. aiohttp_msal-0.7.0/LICENSE +0 -21
  12. aiohttp_msal-0.7.0/aiohttp_msal/settings_base.py +0 -83
  13. aiohttp_msal-0.7.0/aiohttp_msal.egg-info/PKG-INFO +0 -141
  14. aiohttp_msal-0.7.0/aiohttp_msal.egg-info/SOURCES.txt +0 -23
  15. aiohttp_msal-0.7.0/aiohttp_msal.egg-info/dependency_links.txt +0 -1
  16. aiohttp_msal-0.7.0/aiohttp_msal.egg-info/requires.txt +0 -16
  17. aiohttp_msal-0.7.0/aiohttp_msal.egg-info/top_level.txt +0 -2
  18. aiohttp_msal-0.7.0/aiohttp_msal.egg-info/zip-safe +0 -1
  19. aiohttp_msal-0.7.0/pyproject.toml +0 -65
  20. aiohttp_msal-0.7.0/setup.cfg +0 -55
  21. aiohttp_msal-0.7.0/setup.py +0 -6
  22. aiohttp_msal-0.7.0/tests/__init__.py +0 -0
  23. aiohttp_msal-0.7.0/tests/test_init.py +0 -4
  24. aiohttp_msal-0.7.0/tests/test_msal_async.py +0 -12
  25. aiohttp_msal-0.7.0/tests/test_redis_tools.py +0 -60
  26. aiohttp_msal-0.7.0/tests/test_settings.py +0 -11
@@ -1,37 +1,27 @@
1
- Metadata-Version: 2.1
2
- Name: aiohttp_msal
3
- Version: 0.7.0
1
+ Metadata-Version: 2.3
2
+ Name: aiohttp-msal
3
+ Version: 1.0.0
4
4
  Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
5
- Home-page: https://github.com/kellerza/aiohttp_msal
5
+ Keywords: aiohttp,asyncio,msal,oauth
6
6
  Author: Johann Kellerman
7
- Author-email: kellerza@gmail.com
7
+ Author-email: Johann Kellerman <kellerza@gmail.com>
8
8
  License: MIT
9
- Keywords: msal,oauth,aiohttp,asyncio
10
9
  Classifier: Development Status :: 4 - Beta
11
10
  Classifier: Intended Audience :: Developers
12
11
  Classifier: Natural Language :: English
13
- Classifier: Programming Language :: Python :: 3
14
12
  Classifier: Programming Language :: Python :: 3 :: Only
15
- Classifier: Programming Language :: Python :: 3.10
16
13
  Classifier: Programming Language :: Python :: 3.11
17
14
  Classifier: Programming Language :: Python :: 3.12
18
- Requires-Python: >=3.10
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Dist: aiohttp>=3.11.18,<3.13
17
+ Requires-Dist: aiohttp-session>=2.12.1,<3
18
+ Requires-Dist: attrs>=25.3,<26
19
+ Requires-Dist: msal>=1.32.3,<2
20
+ Requires-Dist: aiohttp-session[aioredis]>=2.12.1,<3 ; extra == 'aioredis'
21
+ Requires-Python: >=3.11
22
+ Project-URL: Homepage, https://github.com/kellerza/aiohttp_msal
23
+ Provides-Extra: aioredis
19
24
  Description-Content-Type: text/markdown
20
- License-File: LICENSE
21
- Requires-Dist: msal>=1.30.0
22
- Requires-Dist: aiohttp_session>=2.12
23
- Requires-Dist: aiohttp>=3.8
24
- Provides-Extra: redis
25
- Requires-Dist: aiohttp_session[aioredis]>=2.12; extra == "redis"
26
- Provides-Extra: tests
27
- Requires-Dist: black==24.8.0; extra == "tests"
28
- Requires-Dist: pylint==3.2.6; extra == "tests"
29
- Requires-Dist: flake8; extra == "tests"
30
- Requires-Dist: pytest-aiohttp; extra == "tests"
31
- Requires-Dist: pytest; extra == "tests"
32
- Requires-Dist: pytest-cov; extra == "tests"
33
- Requires-Dist: pytest-asyncio; extra == "tests"
34
- Requires-Dist: pytest-env; extra == "tests"
35
25
 
36
26
  # aiohttp_msal Python library
37
27
 
@@ -139,3 +129,12 @@ def main()
139
129
  # ...
140
130
  # use the Graphclient
141
131
  ```
132
+
133
+ ## Development
134
+
135
+ ```bash
136
+ uv sync --all-extras
137
+ uv tool install ruff
138
+ uv tool install codespell
139
+ uv tool install pyproject-fmt
140
+ ```
@@ -104,3 +104,12 @@ def main()
104
104
  # ...
105
105
  # use the Graphclient
106
106
  ```
107
+
108
+ ## Development
109
+
110
+ ```bash
111
+ uv sync --all-extras
112
+ uv tool install ruff
113
+ uv tool install codespell
114
+ uv tool install pyproject-fmt
115
+ ```
@@ -0,0 +1,128 @@
1
+ [build-system]
2
+ build-backend = "uv_build"
3
+ requires = [ "uv-build" ] # >=0.5.15,<0.6
4
+
5
+ [project]
6
+ name = "aiohttp-msal"
7
+ version = "1.0.0"
8
+ description = "Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp"
9
+ readme = "README.md"
10
+ keywords = [ "aiohttp", "asyncio", "msal", "oauth" ]
11
+ license = { text = "MIT" }
12
+ authors = [ { name = "Johann Kellerman", email = "kellerza@gmail.com" } ]
13
+ requires-python = ">=3.11"
14
+ classifiers = [
15
+ "Development Status :: 4 - Beta",
16
+ "Intended Audience :: Developers",
17
+ "Natural Language :: English",
18
+ "Programming Language :: Python :: 3 :: Only",
19
+ "Programming Language :: Python :: 3.11",
20
+ "Programming Language :: Python :: 3.12",
21
+ "Programming Language :: Python :: 3.13",
22
+ ]
23
+ dependencies = [
24
+ "aiohttp>=3.11.18,<3.13",
25
+ "aiohttp-session>=2.12.1,<3",
26
+ "attrs>=25.3,<26",
27
+ "msal>=1.32.3,<2",
28
+ ]
29
+ optional-dependencies.aioredis = [ "aiohttp-session[aioredis]>=2.12.1,<3" ]
30
+ urls.Homepage = "https://github.com/kellerza/aiohttp_msal"
31
+
32
+ [dependency-groups]
33
+ dev = [
34
+ "mypy",
35
+ "pytest",
36
+ "pytest-aiohttp",
37
+ "pytest-asyncio",
38
+ "pytest-cov",
39
+ "pytest-env",
40
+ "types-redis",
41
+ ]
42
+
43
+ [tool.ruff]
44
+ include = [ "aiohttp_msal/**/*.py", "tests/*.py" ]
45
+
46
+ format.line-ending = "lf"
47
+ format.docstring-code-format = true
48
+ lint.select = [
49
+ "A", # flake8-builtins
50
+ "ASYNC", # flake8-async
51
+ "B", # bugbear
52
+ "D", # pydocstyle
53
+ "E", # pycodestyle
54
+ "F", # pyflakes
55
+ "I", # isort
56
+ "PGH", # pygrep-hooks
57
+ "PIE", # flake8-pie
58
+ "PL", # pylint
59
+ "PTH", # flake8-pathlib
60
+ "PYI", # flake8-pyi
61
+ "RUF", # ruff
62
+ "UP", # pyupgrade
63
+ "W", # pycodestyle
64
+ ]
65
+ lint.ignore = [
66
+ "D203",
67
+ "D213",
68
+ "E203",
69
+ "E501",
70
+ "PLC0415",
71
+ "PLR2004",
72
+ "PLW2901",
73
+ "UP047",
74
+ ]
75
+ lint.isort.no-lines-before = [ "future", "standard-library" ]
76
+
77
+ [tool.codespell]
78
+ skip = [ "build/*", "*.json", "*.csv", "**/node_modules/*", "./s2-js/dist/*" ]
79
+ #ignore-words-list = []
80
+
81
+ [tool.pytest.ini_options]
82
+ pythonpath = [ ".", "src" ]
83
+ filterwarnings = "ignore:.+@coroutine.+deprecated.+"
84
+ norecursedirs = [ ".git", "modules" ]
85
+ log_cli = true
86
+ log_cli_level = "DEBUG"
87
+ asyncio_mode = "auto"
88
+ addopts = "--cov=aiohttp_msal --cov-report xml:cov.xml"
89
+ asyncio_default_fixture_loop_scope = "function"
90
+
91
+ env = [
92
+ "X_SP_APP_PW=p1",
93
+ "X_SP_APP_ID=i1",
94
+ "X_SP_AUTHORITY=a1",
95
+
96
+ "Y_SP_APP_PW=p2",
97
+ "Y_SP_APP_ID=i2",
98
+ "Y_SP_AUTHORITY=a2",
99
+
100
+ "A_NUM=5",
101
+ "A_BOOL=True",
102
+
103
+ "B_NUM=10",
104
+ "B_BOOL=False",
105
+ "B_ROOT=/c/",
106
+ ]
107
+
108
+ [tool.mypy]
109
+ disallow_untyped_defs = true
110
+ ignore_missing_imports = true
111
+
112
+ [tool.semantic_release]
113
+ commit = true
114
+ tag = true
115
+ vcs_release = true
116
+ commit_parser = "emoji"
117
+ version_toml = [ "pyproject.toml:project.version" ]
118
+ build_command = "pip install uv && uv build"
119
+ commit_version_number = true
120
+
121
+ # https://python-semantic-release.readthedocs.io/en/latest/multibranch_releases.html#configuring-multibranch-releases
122
+ [tool.semantic_release.branches.main]
123
+ match = "main"
124
+
125
+ [tool.semantic_release.commit_parser_options]
126
+ major_tags = [ ":boom:" ]
127
+ minor_tags = [ ":rocket:" ]
128
+ patch_tags = [ ":ambulance:", ":lock:", ":bug:", ":dolphin:" ]
@@ -1,9 +1,10 @@
1
1
  """aiohttp_msal."""
2
2
 
3
3
  import logging
4
- import typing
4
+ from collections.abc import Awaitable, Callable
5
5
  from functools import wraps
6
6
  from inspect import getfullargspec, iscoroutinefunction
7
+ from typing import TypeVar, TypeVarTuple, cast
7
8
 
8
9
  from aiohttp import ClientSession, web
9
10
  from aiohttp_session import get_session
@@ -14,49 +15,63 @@ from aiohttp_msal.settings import ENV
14
15
 
15
16
  _LOGGER = logging.getLogger(__name__)
16
17
 
17
- VERSION = "0.7.0"
18
+ _T = TypeVar("_T")
19
+ Ts = TypeVarTuple("Ts")
18
20
 
19
21
 
20
22
  def msal_session(
21
- *callbacks: typing.Callable[[AsyncMSAL], bool | typing.Awaitable[bool]],
23
+ *callbacks: Callable[[AsyncMSAL], bool | Awaitable[bool]],
22
24
  at_least_one: bool | None = False,
23
- ) -> typing.Callable:
25
+ ) -> Callable[
26
+ [Callable[[*Ts, AsyncMSAL], Awaitable[_T]]], Callable[[*Ts], Awaitable[_T]]
27
+ ]:
24
28
  """Session decorator.
25
29
 
26
30
  Arguments can include a list of function to perform login tests etc.
27
31
  """
28
32
 
29
- def _session(func: typing.Callable) -> typing.Callable:
33
+ def check_session(
34
+ func: Callable[[*Ts, AsyncMSAL], Awaitable[_T]],
35
+ ) -> Callable[[*Ts], Awaitable[_T]]:
30
36
  @wraps(func)
31
- async def __session(request: web.Request) -> typing.Callable:
37
+ async def wrapper(*args: *Ts) -> _T:
38
+ if len(args) < 1:
39
+ raise AssertionError("Requires a Request as the first parameter")
40
+ request = cast(web.Request, args[0])
32
41
  ses = AsyncMSAL(session=await get_session(request))
33
42
  for c_b in callbacks:
34
43
  _ok = await c_b(ses) if iscoroutinefunction(c_b) else c_b(ses)
35
- if at_least_one and _ok:
36
- break
37
- if not at_least_one and not _ok:
44
+
45
+ if at_least_one:
46
+ if _ok:
47
+ return await func(*args, ses)
48
+ elif not _ok:
38
49
  raise web.HTTPForbidden
39
- return await func(request=request, ses=ses)
50
+
51
+ if at_least_one:
52
+ raise web.HTTPForbidden
53
+ return await func(*args, ses)
40
54
 
41
55
  assert iscoroutinefunction(func), f"Function needs to be a coroutine: {func}"
42
56
  spec = getfullargspec(func)
43
57
  assert "ses" in spec.args, f"Function needs to accept a session 'ses': {func}"
44
- return __session
58
+ return wrapper
45
59
 
46
- return _session
60
+ return check_session
47
61
 
48
62
 
49
- def authenticated(ses: AsyncMSAL) -> bool:
63
+ def auth_ok(ses: AsyncMSAL) -> bool:
50
64
  """Test if session was authenticated."""
51
65
  return bool(ses.mail)
52
66
 
53
67
 
54
68
  def auth_or(
55
- *args: typing.Callable[[AsyncMSAL], bool | typing.Awaitable[bool]]
56
- ) -> typing.Callable[[AsyncMSAL], typing.Awaitable[bool]]:
69
+ *args: Callable[[AsyncMSAL], bool | Awaitable[bool]],
70
+ ) -> Callable[[AsyncMSAL], Awaitable[bool]]:
57
71
  """Ensure either of the methods is valid. An alternative to at_least_one=True.
58
72
 
59
- Arguments can include a list of function to perform login tests etc."""
73
+ Arguments can include a list of function to perform login tests etc.
74
+ """
60
75
 
61
76
  async def or_auth(ses: AsyncMSAL) -> bool:
62
77
  """Or."""
@@ -74,11 +89,10 @@ def auth_or(
74
89
  async def app_init_redis_session(
75
90
  app: web.Application, max_age: int = 3600 * 24 * 90
76
91
  ) -> None:
77
- """OPTIONAL: Initialize aiohttp_session with Redis storage.
92
+ """Init an aiohttp_session with Redis storage helper.
78
93
 
79
94
  You can initialize your own aiohttp_session & storage provider.
80
95
  """
81
- # pylint: disable=import-outside-toplevel
82
96
  from aiohttp_session import redis_storage
83
97
  from redis.asyncio import from_url
84
98
 
@@ -86,7 +100,7 @@ async def app_init_redis_session(
86
100
 
87
101
  _LOGGER.info("Connect to Redis %s", ENV.REDIS)
88
102
  try:
89
- ENV.database = from_url(ENV.REDIS) # pylint: disable=no-member
103
+ ENV.database = from_url(ENV.REDIS)
90
104
  # , encoding="utf-8", decode_responses=True
91
105
  except ConnectionRefusedError as err:
92
106
  raise ConnectionError("Could not connect to REDIS server") from err
@@ -112,7 +126,7 @@ async def check_proxy() -> None:
112
126
  if resp.ok:
113
127
  return
114
128
  raise ConnectionError(await resp.text())
115
- except Exception as err: # pylint: disable=broad-except
129
+ except Exception as err:
116
130
  raise ConnectionError(
117
131
  "No connection to the Internet. Required for OAuth. Check your Proxy?"
118
132
  ) from err
@@ -7,16 +7,24 @@ Once you have the OAuth tokens store in the session, you are free to make reques
7
7
 
8
8
  import asyncio
9
9
  import json
10
- from functools import partial, wraps
11
- from typing import Any, Callable, Literal
10
+ from collections.abc import Callable
11
+ from functools import partial, partialmethod, wraps
12
+ from typing import Any, ClassVar, Literal, Unpack
12
13
 
13
14
  from aiohttp import web
14
- from aiohttp.client import ClientResponse, ClientSession, _RequestContextManager
15
+ from aiohttp.client import (
16
+ ClientResponse,
17
+ ClientSession,
18
+ _RequestContextManager,
19
+ _RequestOptions,
20
+ )
21
+ from aiohttp.typedefs import StrOrURL
15
22
  from aiohttp_session import Session
16
23
  from msal import ConfidentialClientApplication, SerializableTokenCache
17
24
 
18
25
  from aiohttp_msal.settings import ENV
19
26
 
27
+ HttpMethods = Literal["get", "post", "put", "patch", "delete"]
20
28
  HTTP_GET = "get"
21
29
  HTTP_POST = "post"
22
30
  HTTP_PUT = "put"
@@ -62,24 +70,23 @@ class AsyncMSAL:
62
70
  Use until such time as MSAL Python gets a true async version.
63
71
  """
64
72
 
65
- _token_cache: SerializableTokenCache = None
66
- _app: ConfidentialClientApplication = None
67
- _clientsession: ClientSession = None # type: ignore
73
+ _token_cache: SerializableTokenCache
74
+ _app: ConfidentialClientApplication
75
+ client_session: ClassVar[ClientSession | None] = None
68
76
 
69
77
  def __init__(
70
78
  self,
71
- session: Session | dict[str, str],
72
- save_cache: Callable[[Session | dict[str, str]], None] | None = None,
79
+ session: Session | dict[str, Any],
80
+ save_callback: Callable[[Session | dict[str, Any]], None] | None = None,
73
81
  ):
74
82
  """Init the class.
75
83
 
76
- **save_token_cache** will be called if the token cache changes. Optional.
84
+ **save_callback** will be called if the token cache changes. Optional.
77
85
  Not required when the session parameter is an aiohttp_session.Session.
78
86
  """
79
87
  self.session = session
80
- if save_cache:
81
- self.save_token_cache = save_cache
82
- if not isinstance(session, (Session, dict)):
88
+ self.save_callback = save_callback
89
+ if not isinstance(session, Session | dict):
83
90
  raise ValueError(f"session or dict-like object required {session}")
84
91
 
85
92
  @property
@@ -110,12 +117,12 @@ class AsyncMSAL:
110
117
  )
111
118
  return self._app
112
119
 
113
- def _save_token_cache(self) -> None:
120
+ def save_token_cache(self) -> None:
114
121
  """Save the token cache if it changed."""
115
122
  if self.token_cache.has_state_changed:
116
123
  self.session[TOKEN_CACHE] = self.token_cache.serialize()
117
- if hasattr(self, "save_token_cache"):
118
- self.save_token_cache(self.token_cache)
124
+ if self.save_callback:
125
+ self.save_callback(self.session)
119
126
 
120
127
  def build_auth_code_flow(
121
128
  self,
@@ -125,8 +132,8 @@ class AsyncMSAL:
125
132
  **kwargs: Any,
126
133
  ) -> str:
127
134
  """First step - Start the flow."""
128
- self.session[TOKEN_CACHE] = None # type: ignore
129
- self.session[USER_EMAIL] = None # type: ignore
135
+ self.session[TOKEN_CACHE] = None
136
+ self.session[USER_EMAIL] = None
130
137
  self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
131
138
  scopes or DEFAULT_SCOPES,
132
139
  redirect_uri=redirect_uri,
@@ -149,10 +156,9 @@ class AsyncMSAL:
149
156
  raise web.HTTPBadRequest(text=str(result["error"]))
150
157
  if "id_token_claims" not in result:
151
158
  raise web.HTTPBadRequest(text=f"Expected id_token_claims in {result}")
152
- self._save_token_cache()
153
- self.session[USER_EMAIL] = result.get("id_token_claims").get(
154
- "preferred_username"
155
- )
159
+ self.save_token_cache()
160
+ if tok := result.get("id_token_claims"):
161
+ self.session[USER_EMAIL] = tok.get("preferred_username")
156
162
 
157
163
  async def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
158
164
  """Second step - Acquire token, async version."""
@@ -167,7 +173,7 @@ class AsyncMSAL:
167
173
  result = self.app.acquire_token_silent(
168
174
  scopes=scopes or DEFAULT_SCOPES, account=accounts[0]
169
175
  )
170
- self._save_token_cache()
176
+ self.save_token_cache()
171
177
  return result
172
178
  return None
173
179
 
@@ -175,7 +181,9 @@ class AsyncMSAL:
175
181
  """Acquire a token based on username."""
176
182
  return await asyncio.get_event_loop().run_in_executor(None, self.get_token)
177
183
 
178
- async def request(self, method: str, url: str, **kwargs: Any) -> ClientResponse:
184
+ async def request(
185
+ self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
186
+ ) -> ClientResponse:
179
187
  """Make a request to url using an oauth session.
180
188
 
181
189
  :param str url: url to send request to
@@ -184,16 +192,14 @@ class AsyncMSAL:
184
192
  :return: Response of the request
185
193
  :rtype: aiohttp.Response
186
194
  """
187
- if not self._clientsession:
188
- AsyncMSAL._clientsession = ClientSession(trust_env=True)
189
-
190
195
  token = await self.async_get_token()
191
196
  if token is None:
192
197
  raise web.HTTPClientError(text="No login token available.")
193
198
 
194
199
  kwargs = kwargs.copy()
195
200
  # Ensure headers exist & make a copy
196
- kwargs["headers"] = headers = dict(kwargs.get("headers", {}))
201
+ headers = dict[str, str](kwargs.get("headers") or {}) # type:ignore[arg-type]
202
+ kwargs["headers"] = headers
197
203
 
198
204
  headers["Authorization"] = "Bearer " + token["access_token"]
199
205
 
@@ -207,17 +213,27 @@ class AsyncMSAL:
207
213
  if "data" in kwargs:
208
214
  kwargs["data"] = json.dumps(kwargs["data"]) # auto convert to json
209
215
 
210
- response = await self._clientsession.request(method, url, **kwargs)
216
+ if not AsyncMSAL.client_session:
217
+ AsyncMSAL.client_session = ClientSession(trust_env=True)
218
+
219
+ return await AsyncMSAL.client_session.request(method, url, **kwargs)
220
+
221
+ def request_ctx(
222
+ self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
223
+ ) -> _RequestContextManager:
224
+ """Request context manager."""
225
+ return _RequestContextManager(self.request(method, url, **kwargs))
211
226
 
212
- return response
227
+ get = partialmethod(request_ctx, HTTP_GET)
228
+ post = partialmethod(request_ctx, HTTP_POST)
213
229
 
214
- def get(self, url: str, **kwargs: Any): # type:ignore
215
- """GET Request."""
216
- return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))
230
+ # def get(self, url: str, **kwargs: Any) -> _RequestContextManager:
231
+ # """GET Request."""
232
+ # return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))
217
233
 
218
- def post(self, url: str, **kwargs: Any): # type:ignore
219
- """POST request."""
220
- return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
234
+ # def post(self, url: str, **kwargs: Any) -> _RequestContextManager:
235
+ # """POST request."""
236
+ # return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
221
237
 
222
238
  @property
223
239
  def mail(self) -> str:
@@ -4,8 +4,9 @@ import asyncio
4
4
  import json
5
5
  import logging
6
6
  import time
7
+ from collections.abc import AsyncGenerator
7
8
  from contextlib import AsyncExitStack, asynccontextmanager
8
- from typing import Any, AsyncGenerator, Optional
9
+ from typing import Any
9
10
 
10
11
  from redis.asyncio import Redis, from_url
11
12
 
@@ -30,15 +31,15 @@ async def get_redis() -> AsyncGenerator[Redis, None]:
30
31
  try:
31
32
  yield redis
32
33
  finally:
33
- MENV.database = None # type:ignore
34
+ MENV.database = None # type:ignore[assignment]
34
35
  await redis.close()
35
36
 
36
37
 
37
38
  async def session_iter(
38
39
  redis: Redis,
39
40
  *,
40
- match: Optional[dict[str, str]] = None,
41
- key_match: Optional[str] = None,
41
+ match: dict[str, str] | None = None,
42
+ key_match: str | None = None,
42
43
  ) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
43
44
  """Iterate over the Redis keys to find a specific session.
44
45
 
@@ -55,10 +56,10 @@ async def session_iter(
55
56
  sval = await redis.get(key)
56
57
  created, ses = 0, {}
57
58
  try:
58
- val = json.loads(sval) # type: ignore
59
+ val = json.loads(sval) # type: ignore[arg-type]
59
60
  created = int(val["created"])
60
61
  ses = val["session"]
61
- except Exception: # pylint: disable=broad-except
62
+ except Exception:
62
63
  pass
63
64
  if match:
64
65
  # Ensure we match all the supplied terms
@@ -73,7 +74,7 @@ async def session_iter(
73
74
 
74
75
 
75
76
  async def session_clean(
76
- redis: Redis, *, max_age: int = 90, expected_keys: Optional[dict] = None
77
+ redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
77
78
  ) -> None:
78
79
  """Clear session entries older than max_age days."""
79
80
  rem, keep = 0, 0
@@ -93,11 +94,29 @@ async def session_clean(
93
94
  _LOGGER.debug("No sessions removed (%s total)", keep)
94
95
 
95
96
 
96
- def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
97
+ async def invalid_sessions(redis: Redis) -> None:
98
+ """Find & clean invalid sessions."""
99
+ async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
100
+ if not isinstance(key, str):
101
+ key = key.decode()
102
+ sval = await redis.get(key)
103
+ if sval is None:
104
+ continue
105
+ try:
106
+ val: dict = json.loads(sval)
107
+ assert isinstance(val["created"], int)
108
+ assert isinstance(val["session"], dict)
109
+ except Exception as err:
110
+ _LOGGER.warning("Removing session %s: %s", key, err)
111
+ await redis.delete(key)
112
+
113
+
114
+ def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
97
115
  """Create a AsyncMSAL session.
98
116
 
99
117
  When get_token refreshes the token retrieved from Redis, the save_cache callback
100
- will be responsible to update the cache in Redis."""
118
+ will be responsible to update the cache in Redis.
119
+ """
101
120
 
102
121
  async def async_save_cache(_: dict) -> None:
103
122
  """Save the token cache to Redis."""
@@ -111,11 +130,11 @@ def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
111
130
  except RuntimeError:
112
131
  asyncio.run(async_save_cache(*args))
113
132
 
114
- return AsyncMSAL(session, save_cache=save_cache)
133
+ return AsyncMSAL(session, save_callback=save_cache)
115
134
 
116
135
 
117
136
  async def get_session(
118
- email: str, *, redis: Optional[Redis] = None, scope: str = ""
137
+ email: str, *, redis: Redis | None = None, scope: str = ""
119
138
  ) -> AsyncMSAL:
120
139
  """Get a session from Redis."""
121
140
  cnt = 0
@@ -126,7 +145,7 @@ async def get_session(
126
145
  cnt += 1
127
146
  if scope and scope not in str(session.get("token_cache")).lower():
128
147
  continue
129
- return _session_factory(key, str(created), session)
148
+ return _session_factory(key, created, session)
130
149
  msg = f"Session for {email}"
131
150
  if not scope:
132
151
  raise ValueError(f"{msg} not found")
@@ -136,7 +155,7 @@ async def get_session(
136
155
  async def redis_get_json(key: str) -> list | dict | None:
137
156
  """Get a key from redis."""
138
157
  res = await MENV.database.get(key)
139
- if isinstance(res, (str, bytes, bytearray)):
158
+ if isinstance(res, str | bytes | bytearray):
140
159
  return json.loads(res)
141
160
  if res is not None:
142
161
  _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
@@ -148,7 +167,7 @@ async def redis_get(key: str) -> str | None:
148
167
  res = await MENV.database.get(key)
149
168
  if isinstance(res, str):
150
169
  return res
151
- if isinstance(res, (bytes, bytearray)):
170
+ if isinstance(res, bytes | bytearray):
152
171
  return res.decode()
153
172
  if res is not None:
154
173
  _LOGGER.warning("Unexpected type for %s: %s", key, type(res))