aiohttp-msal 0.7.1 (tar.gz) → 1.0.1 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1}/PKG-INFO +27 -28
  2. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1}/README.md +12 -3
  3. aiohttp_msal-1.0.1/pyproject.toml +128 -0
  4. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/__init__.py +27 -20
  5. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/msal_async.py +55 -55
  6. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/redis_tools.py +33 -14
  7. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/routes.py +6 -7
  8. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/settings.py +16 -12
  9. aiohttp_msal-1.0.1/src/aiohttp_msal/settings_base.py +93 -0
  10. {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/user_info.py +8 -4
  11. aiohttp_msal-0.7.1/LICENSE +0 -21
  12. aiohttp_msal-0.7.1/aiohttp_msal/settings_base.py +0 -83
  13. aiohttp_msal-0.7.1/aiohttp_msal.egg-info/PKG-INFO +0 -141
  14. aiohttp_msal-0.7.1/aiohttp_msal.egg-info/SOURCES.txt +0 -23
  15. aiohttp_msal-0.7.1/aiohttp_msal.egg-info/dependency_links.txt +0 -1
  16. aiohttp_msal-0.7.1/aiohttp_msal.egg-info/requires.txt +0 -16
  17. aiohttp_msal-0.7.1/aiohttp_msal.egg-info/top_level.txt +0 -2
  18. aiohttp_msal-0.7.1/aiohttp_msal.egg-info/zip-safe +0 -1
  19. aiohttp_msal-0.7.1/pyproject.toml +0 -65
  20. aiohttp_msal-0.7.1/setup.cfg +0 -55
  21. aiohttp_msal-0.7.1/setup.py +0 -6
  22. aiohttp_msal-0.7.1/tests/__init__.py +0 -0
  23. aiohttp_msal-0.7.1/tests/test_init.py +0 -79
  24. aiohttp_msal-0.7.1/tests/test_msal_async.py +0 -12
  25. aiohttp_msal-0.7.1/tests/test_redis_tools.py +0 -60
  26. aiohttp_msal-0.7.1/tests/test_settings.py +0 -11
@@ -1,39 +1,29 @@
1
- Metadata-Version: 2.1
2
- Name: aiohttp_msal
3
- Version: 0.7.1
1
+ Metadata-Version: 2.3
2
+ Name: aiohttp-msal
3
+ Version: 1.0.1
4
4
  Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
5
- Home-page: https://github.com/kellerza/aiohttp_msal
5
+ Keywords: aiohttp,asyncio,msal,oauth
6
6
  Author: Johann Kellerman
7
- Author-email: kellerza@gmail.com
7
+ Author-email: Johann Kellerman <kellerza@gmail.com>
8
8
  License: MIT
9
- Keywords: msal,oauth,aiohttp,asyncio
10
9
  Classifier: Development Status :: 4 - Beta
11
10
  Classifier: Intended Audience :: Developers
12
11
  Classifier: Natural Language :: English
13
- Classifier: Programming Language :: Python :: 3
14
12
  Classifier: Programming Language :: Python :: 3 :: Only
15
- Classifier: Programming Language :: Python :: 3.10
16
13
  Classifier: Programming Language :: Python :: 3.11
17
14
  Classifier: Programming Language :: Python :: 3.12
18
- Requires-Python: >=3.10
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Dist: aiohttp>=3.11.18,<3.13
17
+ Requires-Dist: aiohttp-session>=2.12.1,<3
18
+ Requires-Dist: attrs>=25.3,<26
19
+ Requires-Dist: msal>=1.32.3,<2
20
+ Requires-Dist: aiohttp-session[aioredis]>=2.12.1,<3 ; extra == 'aioredis'
21
+ Requires-Python: >=3.11
22
+ Project-URL: Homepage, https://github.com/kellerza/aiohttp_msal
23
+ Provides-Extra: aioredis
19
24
  Description-Content-Type: text/markdown
20
- License-File: LICENSE
21
- Requires-Dist: msal>=1.30.0
22
- Requires-Dist: aiohttp_session>=2.12
23
- Requires-Dist: aiohttp>=3.8
24
- Provides-Extra: redis
25
- Requires-Dist: aiohttp_session[aioredis]>=2.12; extra == "redis"
26
- Provides-Extra: tests
27
- Requires-Dist: black==24.8.0; extra == "tests"
28
- Requires-Dist: pylint==3.2.6; extra == "tests"
29
- Requires-Dist: flake8; extra == "tests"
30
- Requires-Dist: pytest-aiohttp; extra == "tests"
31
- Requires-Dist: pytest; extra == "tests"
32
- Requires-Dist: pytest-cov; extra == "tests"
33
- Requires-Dist: pytest-asyncio; extra == "tests"
34
- Requires-Dist: pytest-env; extra == "tests"
35
-
36
- # aiohttp_msal Python library
25
+
26
+ # Async based MSAL helper for aiohttp - aiohttp_msal Python library
37
27
 
38
28
  Authorization Code Flow Helper. Learn more about auth-code-flow at
39
29
  <https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow>
@@ -134,8 +124,17 @@ from aiohttp_msal.redis_tools import get_session
134
124
  def main()
135
125
  # Uses the redis.asyncio driver to retrieve the current token
136
126
  # Will update the token_cache if a RefreshToken was used
137
- ases = asyncio.run(get_session(MYEMAIL))
138
- client = GraphClient(ases.get_token)
127
+ ses = asyncio.run(get_session(MYEMAIL))
128
+ client = GraphClient(ses.get_token)
139
129
  # ...
140
130
  # use the Graphclient
141
131
  ```
132
+
133
+ ## Development
134
+
135
+ ```bash
136
+ uv sync --all-extras
137
+ uv tool install ruff
138
+ uv tool install codespell
139
+ uv tool install pyproject-fmt
140
+ ```
@@ -1,4 +1,4 @@
1
- # aiohttp_msal Python library
1
+ # Async based MSAL helper for aiohttp - aiohttp_msal Python library
2
2
 
3
3
  Authorization Code Flow Helper. Learn more about auth-code-flow at
4
4
  <https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow>
@@ -99,8 +99,17 @@ from aiohttp_msal.redis_tools import get_session
99
99
  def main()
100
100
  # Uses the redis.asyncio driver to retrieve the current token
101
101
  # Will update the token_cache if a RefreshToken was used
102
- ases = asyncio.run(get_session(MYEMAIL))
103
- client = GraphClient(ases.get_token)
102
+ ses = asyncio.run(get_session(MYEMAIL))
103
+ client = GraphClient(ses.get_token)
104
104
  # ...
105
105
  # use the Graphclient
106
106
  ```
107
+
108
+ ## Development
109
+
110
+ ```bash
111
+ uv sync --all-extras
112
+ uv tool install ruff
113
+ uv tool install codespell
114
+ uv tool install pyproject-fmt
115
+ ```
@@ -0,0 +1,128 @@
1
+ [build-system]
2
+ build-backend = "uv_build"
3
+ requires = [ "uv-build" ] # >=0.5.15,<0.6
4
+
5
+ [project]
6
+ name = "aiohttp-msal"
7
+ version = "1.0.1"
8
+ description = "Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp"
9
+ readme = "README.md"
10
+ keywords = [ "aiohttp", "asyncio", "msal", "oauth" ]
11
+ license = { text = "MIT" }
12
+ authors = [ { name = "Johann Kellerman", email = "kellerza@gmail.com" } ]
13
+ requires-python = ">=3.11"
14
+ classifiers = [
15
+ "Development Status :: 4 - Beta",
16
+ "Intended Audience :: Developers",
17
+ "Natural Language :: English",
18
+ "Programming Language :: Python :: 3 :: Only",
19
+ "Programming Language :: Python :: 3.11",
20
+ "Programming Language :: Python :: 3.12",
21
+ "Programming Language :: Python :: 3.13",
22
+ ]
23
+ dependencies = [
24
+ "aiohttp>=3.11.18,<3.13",
25
+ "aiohttp-session>=2.12.1,<3",
26
+ "attrs>=25.3,<26",
27
+ "msal>=1.32.3,<2",
28
+ ]
29
+ optional-dependencies.aioredis = [ "aiohttp-session[aioredis]>=2.12.1,<3" ]
30
+ urls.Homepage = "https://github.com/kellerza/aiohttp_msal"
31
+
32
+ [dependency-groups]
33
+ dev = [
34
+ "mypy",
35
+ "pytest",
36
+ "pytest-aiohttp",
37
+ "pytest-asyncio",
38
+ "pytest-cov",
39
+ "pytest-env",
40
+ "types-redis",
41
+ ]
42
+
43
+ [tool.ruff]
44
+ include = [ "aiohttp_msal/**/*.py", "tests/*.py" ]
45
+
46
+ format.line-ending = "lf"
47
+ format.docstring-code-format = true
48
+ lint.select = [
49
+ "A", # flake8-builtins
50
+ "ASYNC", # flake8-async
51
+ "B", # bugbear
52
+ "D", # pydocstyle
53
+ "E", # pycodestyle
54
+ "F", # pyflakes
55
+ "I", # isort
56
+ "PGH", # pygrep-hooks
57
+ "PIE", # flake8-pie
58
+ "PL", # pylint
59
+ "PTH", # flake8-pathlib
60
+ "PYI", # flake8-pyi
61
+ "RUF", # ruff
62
+ "UP", # pyupgrade
63
+ "W", # pycodestyle
64
+ ]
65
+ lint.ignore = [
66
+ "D203",
67
+ "D213",
68
+ "E203",
69
+ "E501",
70
+ "PLC0415",
71
+ "PLR2004",
72
+ "PLW2901",
73
+ "UP047",
74
+ ]
75
+ lint.isort.no-lines-before = [ "future", "standard-library" ]
76
+
77
+ [tool.codespell]
78
+ skip = [ "build/*", "*.json", "*.csv", "**/node_modules/*", "./s2-js/dist/*" ]
79
+ #ignore-words-list = []
80
+
81
+ [tool.pytest.ini_options]
82
+ pythonpath = [ ".", "src" ]
83
+ filterwarnings = "ignore:.+@coroutine.+deprecated.+"
84
+ norecursedirs = [ ".git", "modules" ]
85
+ log_cli = true
86
+ log_cli_level = "DEBUG"
87
+ asyncio_mode = "auto"
88
+ addopts = "--cov=aiohttp_msal --cov-report xml:cov.xml"
89
+ asyncio_default_fixture_loop_scope = "function"
90
+
91
+ env = [
92
+ "X_SP_APP_PW=p1",
93
+ "X_SP_APP_ID=i1",
94
+ "X_SP_AUTHORITY=a1",
95
+
96
+ "Y_SP_APP_PW=p2",
97
+ "Y_SP_APP_ID=i2",
98
+ "Y_SP_AUTHORITY=a2",
99
+
100
+ "A_NUM=5",
101
+ "A_BOOL=True",
102
+
103
+ "B_NUM=10",
104
+ "B_BOOL=False",
105
+ "B_ROOT=/c/",
106
+ ]
107
+
108
+ [tool.mypy]
109
+ disallow_untyped_defs = true
110
+ ignore_missing_imports = true
111
+
112
+ [tool.semantic_release]
113
+ commit = true
114
+ tag = true
115
+ vcs_release = true
116
+ commit_parser = "emoji"
117
+ version_toml = [ "pyproject.toml:project.version" ]
118
+ build_command = "pip install uv && uv build"
119
+ commit_version_number = true
120
+
121
+ # https://python-semantic-release.readthedocs.io/en/latest/multibranch_releases.html#configuring-multibranch-releases
122
+ [tool.semantic_release.branches.main]
123
+ match = "main"
124
+
125
+ [tool.semantic_release.commit_parser_options]
126
+ major_tags = [ ":boom:" ]
127
+ minor_tags = [ ":rocket:" ]
128
+ patch_tags = [ ":ambulance:", ":lock:", ":bug:", ":dolphin:" ]
@@ -1,9 +1,10 @@
1
1
  """aiohttp_msal."""
2
2
 
3
3
  import logging
4
- import typing
4
+ from collections.abc import Awaitable, Callable
5
5
  from functools import wraps
6
6
  from inspect import getfullargspec, iscoroutinefunction
7
+ from typing import TypeVar, TypeVarTuple, cast
7
8
 
8
9
  from aiohttp import ClientSession, web
9
10
  from aiohttp_session import get_session
@@ -14,43 +15,49 @@ from aiohttp_msal.settings import ENV
14
15
 
15
16
  _LOGGER = logging.getLogger(__name__)
16
17
 
17
- VERSION = "0.7.1"
18
+ _T = TypeVar("_T")
19
+ Ts = TypeVarTuple("Ts")
18
20
 
19
21
 
20
22
  def msal_session(
21
- *callbacks: typing.Callable[[AsyncMSAL], bool | typing.Awaitable[bool]],
23
+ *callbacks: Callable[[AsyncMSAL], bool | Awaitable[bool]],
22
24
  at_least_one: bool | None = False,
23
- ) -> typing.Callable:
25
+ ) -> Callable[
26
+ [Callable[[*Ts, AsyncMSAL], Awaitable[_T]]], Callable[[*Ts], Awaitable[_T]]
27
+ ]:
24
28
  """Session decorator.
25
29
 
26
30
  Arguments can include a list of function to perform login tests etc.
27
31
  """
28
32
 
29
- def _session(func: typing.Callable) -> typing.Callable:
33
+ def check_session(
34
+ func: Callable[[*Ts, AsyncMSAL], Awaitable[_T]],
35
+ ) -> Callable[[*Ts], Awaitable[_T]]:
30
36
  @wraps(func)
31
- async def __session(request: web.Request) -> typing.Callable:
37
+ async def wrapper(*args: *Ts) -> _T:
38
+ if len(args) < 1:
39
+ raise AssertionError("Requires a Request as the first parameter")
40
+ request = cast(web.Request, args[0])
32
41
  ses = AsyncMSAL(session=await get_session(request))
33
42
  for c_b in callbacks:
34
43
  _ok = await c_b(ses) if iscoroutinefunction(c_b) else c_b(ses)
35
44
 
36
45
  if at_least_one:
37
46
  if _ok:
38
- return await func(request=request, ses=ses)
39
- continue
40
-
41
- if not _ok:
47
+ return await func(*args, ses)
48
+ elif not _ok:
42
49
  raise web.HTTPForbidden
43
50
 
44
51
  if at_least_one:
45
52
  raise web.HTTPForbidden
46
- return await func(request=request, ses=ses)
53
+ return await func(*args, ses)
47
54
 
48
55
  assert iscoroutinefunction(func), f"Function needs to be a coroutine: {func}"
49
56
  spec = getfullargspec(func)
50
57
  assert "ses" in spec.args, f"Function needs to accept a session 'ses': {func}"
51
- return __session
58
+ return wrapper
52
59
 
53
- return _session
60
+ return check_session
54
61
 
55
62
 
56
63
  def auth_ok(ses: AsyncMSAL) -> bool:
@@ -59,11 +66,12 @@ def auth_ok(ses: AsyncMSAL) -> bool:
59
66
 
60
67
 
61
68
  def auth_or(
62
- *args: typing.Callable[[AsyncMSAL], bool | typing.Awaitable[bool]]
63
- ) -> typing.Callable[[AsyncMSAL], typing.Awaitable[bool]]:
69
+ *args: Callable[[AsyncMSAL], bool | Awaitable[bool]],
70
+ ) -> Callable[[AsyncMSAL], Awaitable[bool]]:
64
71
  """Ensure either of the methods is valid. An alternative to at_least_one=True.
65
72
 
66
- Arguments can include a list of function to perform login tests etc."""
73
+ Arguments can include a list of function to perform login tests etc.
74
+ """
67
75
 
68
76
  async def or_auth(ses: AsyncMSAL) -> bool:
69
77
  """Or."""
@@ -81,11 +89,10 @@ def auth_or(
81
89
  async def app_init_redis_session(
82
90
  app: web.Application, max_age: int = 3600 * 24 * 90
83
91
  ) -> None:
84
- """OPTIONAL: Initialize aiohttp_session with Redis storage.
92
+ """Init an aiohttp_session with Redis storage helper.
85
93
 
86
94
  You can initialize your own aiohttp_session & storage provider.
87
95
  """
88
- # pylint: disable=import-outside-toplevel
89
96
  from aiohttp_session import redis_storage
90
97
  from redis.asyncio import from_url
91
98
 
@@ -93,7 +100,7 @@ async def app_init_redis_session(
93
100
 
94
101
  _LOGGER.info("Connect to Redis %s", ENV.REDIS)
95
102
  try:
96
- ENV.database = from_url(ENV.REDIS) # pylint: disable=no-member
103
+ ENV.database = from_url(ENV.REDIS)
97
104
  # , encoding="utf-8", decode_responses=True
98
105
  except ConnectionRefusedError as err:
99
106
  raise ConnectionError("Could not connect to REDIS server") from err
@@ -119,7 +126,7 @@ async def check_proxy() -> None:
119
126
  if resp.ok:
120
127
  return
121
128
  raise ConnectionError(await resp.text())
122
- except Exception as err: # pylint: disable=broad-except
129
+ except Exception as err:
123
130
  raise ConnectionError(
124
131
  "No connection to the Internet. Required for OAuth. Check your Proxy?"
125
132
  ) from err
@@ -7,16 +7,24 @@ Once you have the OAuth tokens store in the session, you are free to make reques
7
7
 
8
8
  import asyncio
9
9
  import json
10
- from functools import partial, wraps
11
- from typing import Any, Callable, Literal
10
+ from collections.abc import Callable
11
+ from functools import cached_property, partial, partialmethod, wraps
12
+ from typing import Any, ClassVar, Literal, Unpack
12
13
 
13
14
  from aiohttp import web
14
- from aiohttp.client import ClientResponse, ClientSession, _RequestContextManager
15
+ from aiohttp.client import (
16
+ ClientResponse,
17
+ ClientSession,
18
+ _RequestContextManager,
19
+ _RequestOptions,
20
+ )
21
+ from aiohttp.typedefs import StrOrURL
15
22
  from aiohttp_session import Session
16
23
  from msal import ConfidentialClientApplication, SerializableTokenCache
17
24
 
18
25
  from aiohttp_msal.settings import ENV
19
26
 
27
+ HttpMethods = Literal["get", "post", "put", "patch", "delete"]
20
28
  HTTP_GET = "get"
21
29
  HTTP_POST = "post"
22
30
  HTTP_PUT = "put"
@@ -62,60 +70,51 @@ class AsyncMSAL:
62
70
  Use until such time as MSAL Python gets a true async version.
63
71
  """
64
72
 
65
- _token_cache: SerializableTokenCache = None
66
- _app: ConfidentialClientApplication = None
67
- _clientsession: ClientSession = None # type: ignore
73
+ client_session: ClassVar[ClientSession | None] = None
68
74
 
69
75
  def __init__(
70
76
  self,
71
- session: Session | dict[str, str],
72
- save_cache: Callable[[Session | dict[str, str]], None] | None = None,
77
+ session: Session | dict[str, Any],
78
+ save_callback: Callable[[Session | dict[str, Any]], None] | None = None,
73
79
  ):
74
80
  """Init the class.
75
81
 
76
- **save_token_cache** will be called if the token cache changes. Optional.
82
+ **save_callback** will be called if the token cache changes. Optional.
77
83
  Not required when the session parameter is an aiohttp_session.Session.
78
84
  """
79
85
  self.session = session
80
- if save_cache:
81
- self.save_token_cache = save_cache
82
- if not isinstance(session, (Session, dict)):
86
+ self.save_callback = save_callback
87
+ if not isinstance(session, Session | dict):
83
88
  raise ValueError(f"session or dict-like object required {session}")
84
89
 
85
- @property
90
+ @cached_property
86
91
  def token_cache(self) -> SerializableTokenCache:
87
92
  """Get the token_cache."""
88
- if not self._token_cache:
89
- self._token_cache = SerializableTokenCache()
90
- # _load_token_cache
91
- if self.session and self.session.get(TOKEN_CACHE):
92
- self._token_cache.deserialize(self.session[TOKEN_CACHE])
93
-
94
- return self._token_cache
93
+ res = SerializableTokenCache()
94
+ if self.session and self.session.get(TOKEN_CACHE):
95
+ res.deserialize(self.session[TOKEN_CACHE])
96
+ return res
95
97
 
96
- @property
98
+ @cached_property
97
99
  def app(self) -> ConfidentialClientApplication:
98
100
  """Create the application using the cache.
99
101
 
100
102
  Based on: https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76
101
103
  """
102
- if not self._app:
103
- token_cache = self.token_cache
104
- self._app = ConfidentialClientApplication(
105
- client_id=ENV.SP_APP_ID,
106
- client_credential=ENV.SP_APP_PW,
107
- authority=ENV.SP_AUTHORITY, # common/oauth2/v2.0/token'
108
- validate_authority=False,
109
- token_cache=token_cache,
110
- )
111
- return self._app
104
+ return ConfidentialClientApplication(
105
+ client_id=ENV.SP_APP_ID,
106
+ client_credential=ENV.SP_APP_PW,
107
+ authority=ENV.SP_AUTHORITY, # common/oauth2/v2.0/token'
108
+ validate_authority=False,
109
+ token_cache=self.token_cache,
110
+ )
112
111
 
113
- def _save_token_cache(self) -> None:
112
+ def save_token_cache(self) -> None:
114
113
  """Save the token cache if it changed."""
115
114
  if self.token_cache.has_state_changed:
116
115
  self.session[TOKEN_CACHE] = self.token_cache.serialize()
117
- if hasattr(self, "save_token_cache"):
118
- self.save_token_cache(self.token_cache)
116
+ if self.save_callback:
117
+ self.save_callback(self.session)
119
118
 
120
119
  def build_auth_code_flow(
121
120
  self,
@@ -125,8 +124,8 @@ class AsyncMSAL:
125
124
  **kwargs: Any,
126
125
  ) -> str:
127
126
  """First step - Start the flow."""
128
- self.session[TOKEN_CACHE] = None # type: ignore
129
- self.session[USER_EMAIL] = None # type: ignore
127
+ self.session[TOKEN_CACHE] = None
128
+ self.session[USER_EMAIL] = None
130
129
  self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
131
130
  scopes or DEFAULT_SCOPES,
132
131
  redirect_uri=redirect_uri,
@@ -149,10 +148,9 @@ class AsyncMSAL:
149
148
  raise web.HTTPBadRequest(text=str(result["error"]))
150
149
  if "id_token_claims" not in result:
151
150
  raise web.HTTPBadRequest(text=f"Expected id_token_claims in {result}")
152
- self._save_token_cache()
153
- self.session[USER_EMAIL] = result.get("id_token_claims").get(
154
- "preferred_username"
155
- )
151
+ self.save_token_cache()
152
+ if tok := result.get("id_token_claims"):
153
+ self.session[USER_EMAIL] = tok.get("preferred_username")
156
154
 
157
155
  async def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
158
156
  """Second step - Acquire token, async version."""
@@ -167,7 +165,7 @@ class AsyncMSAL:
167
165
  result = self.app.acquire_token_silent(
168
166
  scopes=scopes or DEFAULT_SCOPES, account=accounts[0]
169
167
  )
170
- self._save_token_cache()
168
+ self.save_token_cache()
171
169
  return result
172
170
  return None
173
171
 
@@ -175,7 +173,9 @@ class AsyncMSAL:
175
173
  """Acquire a token based on username."""
176
174
  return await asyncio.get_event_loop().run_in_executor(None, self.get_token)
177
175
 
178
- async def request(self, method: str, url: str, **kwargs: Any) -> ClientResponse:
176
+ async def request(
177
+ self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
178
+ ) -> ClientResponse:
179
179
  """Make a request to url using an oauth session.
180
180
 
181
181
  :param str url: url to send request to
@@ -184,16 +184,14 @@ class AsyncMSAL:
184
184
  :return: Response of the request
185
185
  :rtype: aiohttp.Response
186
186
  """
187
- if not self._clientsession:
188
- AsyncMSAL._clientsession = ClientSession(trust_env=True)
189
-
190
187
  token = await self.async_get_token()
191
188
  if token is None:
192
189
  raise web.HTTPClientError(text="No login token available.")
193
190
 
194
191
  kwargs = kwargs.copy()
195
192
  # Ensure headers exist & make a copy
196
- kwargs["headers"] = headers = dict(kwargs.get("headers", {}))
193
+ headers = dict[str, str](kwargs.get("headers") or {}) # type:ignore[arg-type]
194
+ kwargs["headers"] = headers
197
195
 
198
196
  headers["Authorization"] = "Bearer " + token["access_token"]
199
197
 
@@ -207,22 +205,24 @@ class AsyncMSAL:
207
205
  if "data" in kwargs:
208
206
  kwargs["data"] = json.dumps(kwargs["data"]) # auto convert to json
209
207
 
210
- response = await self._clientsession.request(method, url, **kwargs)
208
+ if not AsyncMSAL.client_session:
209
+ AsyncMSAL.client_session = ClientSession(trust_env=True)
211
210
 
212
- return response
211
+ return await AsyncMSAL.client_session.request(method, url, **kwargs)
213
212
 
214
- def get(self, url: str, **kwargs: Any): # type:ignore
215
- """GET Request."""
216
- return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))
213
+ def request_ctx(
214
+ self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
215
+ ) -> _RequestContextManager:
216
+ """Request context manager."""
217
+ return _RequestContextManager(self.request(method, url, **kwargs))
217
218
 
218
- def post(self, url: str, **kwargs: Any): # type:ignore
219
- """POST request."""
220
- return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
219
+ get = partialmethod(request_ctx, HTTP_GET)
220
+ post = partialmethod(request_ctx, HTTP_POST)
221
221
 
222
222
  @property
223
223
  def mail(self) -> str:
224
224
  """User email."""
225
- return self.session.get("mail", "")
225
+ return self.session.get(USER_EMAIL, "")
226
226
 
227
227
  @property
228
228
  def manager_mail(self) -> str:
@@ -4,8 +4,9 @@ import asyncio
4
4
  import json
5
5
  import logging
6
6
  import time
7
+ from collections.abc import AsyncGenerator
7
8
  from contextlib import AsyncExitStack, asynccontextmanager
8
- from typing import Any, AsyncGenerator, Optional
9
+ from typing import Any
9
10
 
10
11
  from redis.asyncio import Redis, from_url
11
12
 
@@ -30,15 +31,15 @@ async def get_redis() -> AsyncGenerator[Redis, None]:
30
31
  try:
31
32
  yield redis
32
33
  finally:
33
- MENV.database = None # type:ignore
34
+ MENV.database = None # type:ignore[assignment]
34
35
  await redis.close()
35
36
 
36
37
 
37
38
  async def session_iter(
38
39
  redis: Redis,
39
40
  *,
40
- match: Optional[dict[str, str]] = None,
41
- key_match: Optional[str] = None,
41
+ match: dict[str, str] | None = None,
42
+ key_match: str | None = None,
42
43
  ) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
43
44
  """Iterate over the Redis keys to find a specific session.
44
45
 
@@ -55,10 +56,10 @@ async def session_iter(
55
56
  sval = await redis.get(key)
56
57
  created, ses = 0, {}
57
58
  try:
58
- val = json.loads(sval) # type: ignore
59
+ val = json.loads(sval) # type: ignore[arg-type]
59
60
  created = int(val["created"])
60
61
  ses = val["session"]
61
- except Exception: # pylint: disable=broad-except
62
+ except Exception:
62
63
  pass
63
64
  if match:
64
65
  # Ensure we match all the supplied terms
@@ -73,7 +74,7 @@ async def session_iter(
73
74
 
74
75
 
75
76
  async def session_clean(
76
- redis: Redis, *, max_age: int = 90, expected_keys: Optional[dict] = None
77
+ redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
77
78
  ) -> None:
78
79
  """Clear session entries older than max_age days."""
79
80
  rem, keep = 0, 0
@@ -93,11 +94,29 @@ async def session_clean(
93
94
  _LOGGER.debug("No sessions removed (%s total)", keep)
94
95
 
95
96
 
96
- def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
97
+ async def invalid_sessions(redis: Redis) -> None:
98
+ """Find & clean invalid sessions."""
99
+ async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
100
+ if not isinstance(key, str):
101
+ key = key.decode()
102
+ sval = await redis.get(key)
103
+ if sval is None:
104
+ continue
105
+ try:
106
+ val: dict = json.loads(sval)
107
+ assert isinstance(val["created"], int)
108
+ assert isinstance(val["session"], dict)
109
+ except Exception as err:
110
+ _LOGGER.warning("Removing session %s: %s", key, err)
111
+ await redis.delete(key)
112
+
113
+
114
+ def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
97
115
  """Create a AsyncMSAL session.
98
116
 
99
117
  When get_token refreshes the token retrieved from Redis, the save_cache callback
100
- will be responsible to update the cache in Redis."""
118
+ will be responsible to update the cache in Redis.
119
+ """
101
120
 
102
121
  async def async_save_cache(_: dict) -> None:
103
122
  """Save the token cache to Redis."""
@@ -111,11 +130,11 @@ def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
111
130
  except RuntimeError:
112
131
  asyncio.run(async_save_cache(*args))
113
132
 
114
- return AsyncMSAL(session, save_cache=save_cache)
133
+ return AsyncMSAL(session, save_callback=save_cache)
115
134
 
116
135
 
117
136
  async def get_session(
118
- email: str, *, redis: Optional[Redis] = None, scope: str = ""
137
+ email: str, *, redis: Redis | None = None, scope: str = ""
119
138
  ) -> AsyncMSAL:
120
139
  """Get a session from Redis."""
121
140
  cnt = 0
@@ -126,7 +145,7 @@ async def get_session(
126
145
  cnt += 1
127
146
  if scope and scope not in str(session.get("token_cache")).lower():
128
147
  continue
129
- return _session_factory(key, str(created), session)
148
+ return _session_factory(key, created, session)
130
149
  msg = f"Session for {email}"
131
150
  if not scope:
132
151
  raise ValueError(f"{msg} not found")
@@ -136,7 +155,7 @@ async def get_session(
136
155
  async def redis_get_json(key: str) -> list | dict | None:
137
156
  """Get a key from redis."""
138
157
  res = await MENV.database.get(key)
139
- if isinstance(res, (str, bytes, bytearray)):
158
+ if isinstance(res, str | bytes | bytearray):
140
159
  return json.loads(res)
141
160
  if res is not None:
142
161
  _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
@@ -148,7 +167,7 @@ async def redis_get(key: str) -> str | None:
148
167
  res = await MENV.database.get(key)
149
168
  if isinstance(res, str):
150
169
  return res
151
- if isinstance(res, (bytes, bytearray)):
170
+ if isinstance(res, bytes | bytearray):
152
171
  return res.decode()
153
172
  if res is not None:
154
173
  _LOGGER.warning("Unexpected type for %s: %s", key, type(res))