aiohttp-msal 0.7.0__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0}/PKG-INFO +23 -24
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0}/README.md +9 -0
- aiohttp_msal-1.0.0/pyproject.toml +128 -0
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/__init__.py +34 -20
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/msal_async.py +51 -35
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/redis_tools.py +33 -14
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/routes.py +11 -12
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/settings.py +16 -12
- aiohttp_msal-1.0.0/src/aiohttp_msal/settings_base.py +93 -0
- {aiohttp_msal-0.7.0 → aiohttp_msal-1.0.0/src}/aiohttp_msal/user_info.py +8 -4
- aiohttp_msal-0.7.0/LICENSE +0 -21
- aiohttp_msal-0.7.0/aiohttp_msal/settings_base.py +0 -83
- aiohttp_msal-0.7.0/aiohttp_msal.egg-info/PKG-INFO +0 -141
- aiohttp_msal-0.7.0/aiohttp_msal.egg-info/SOURCES.txt +0 -23
- aiohttp_msal-0.7.0/aiohttp_msal.egg-info/dependency_links.txt +0 -1
- aiohttp_msal-0.7.0/aiohttp_msal.egg-info/requires.txt +0 -16
- aiohttp_msal-0.7.0/aiohttp_msal.egg-info/top_level.txt +0 -2
- aiohttp_msal-0.7.0/aiohttp_msal.egg-info/zip-safe +0 -1
- aiohttp_msal-0.7.0/pyproject.toml +0 -65
- aiohttp_msal-0.7.0/setup.cfg +0 -55
- aiohttp_msal-0.7.0/setup.py +0 -6
- aiohttp_msal-0.7.0/tests/__init__.py +0 -0
- aiohttp_msal-0.7.0/tests/test_init.py +0 -4
- aiohttp_msal-0.7.0/tests/test_msal_async.py +0 -12
- aiohttp_msal-0.7.0/tests/test_redis_tools.py +0 -60
- aiohttp_msal-0.7.0/tests/test_settings.py +0 -11
|
@@ -1,37 +1,27 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
2
|
-
Name:
|
|
3
|
-
Version: 0.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: aiohttp-msal
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
|
|
5
|
-
|
|
5
|
+
Keywords: aiohttp,asyncio,msal,oauth
|
|
6
6
|
Author: Johann Kellerman
|
|
7
|
-
Author-email: kellerza@gmail.com
|
|
7
|
+
Author-email: Johann Kellerman <kellerza@gmail.com>
|
|
8
8
|
License: MIT
|
|
9
|
-
Keywords: msal,oauth,aiohttp,asyncio
|
|
10
9
|
Classifier: Development Status :: 4 - Beta
|
|
11
10
|
Classifier: Intended Audience :: Developers
|
|
12
11
|
Classifier: Natural Language :: English
|
|
13
|
-
Classifier: Programming Language :: Python :: 3
|
|
14
12
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
15
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
16
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
-
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Requires-Dist: aiohttp>=3.11.18,<3.13
|
|
17
|
+
Requires-Dist: aiohttp-session>=2.12.1,<3
|
|
18
|
+
Requires-Dist: attrs>=25.3,<26
|
|
19
|
+
Requires-Dist: msal>=1.32.3,<2
|
|
20
|
+
Requires-Dist: aiohttp-session[aioredis]>=2.12.1,<3 ; extra == 'aioredis'
|
|
21
|
+
Requires-Python: >=3.11
|
|
22
|
+
Project-URL: Homepage, https://github.com/kellerza/aiohttp_msal
|
|
23
|
+
Provides-Extra: aioredis
|
|
19
24
|
Description-Content-Type: text/markdown
|
|
20
|
-
License-File: LICENSE
|
|
21
|
-
Requires-Dist: msal>=1.30.0
|
|
22
|
-
Requires-Dist: aiohttp_session>=2.12
|
|
23
|
-
Requires-Dist: aiohttp>=3.8
|
|
24
|
-
Provides-Extra: redis
|
|
25
|
-
Requires-Dist: aiohttp_session[aioredis]>=2.12; extra == "redis"
|
|
26
|
-
Provides-Extra: tests
|
|
27
|
-
Requires-Dist: black==24.8.0; extra == "tests"
|
|
28
|
-
Requires-Dist: pylint==3.2.6; extra == "tests"
|
|
29
|
-
Requires-Dist: flake8; extra == "tests"
|
|
30
|
-
Requires-Dist: pytest-aiohttp; extra == "tests"
|
|
31
|
-
Requires-Dist: pytest; extra == "tests"
|
|
32
|
-
Requires-Dist: pytest-cov; extra == "tests"
|
|
33
|
-
Requires-Dist: pytest-asyncio; extra == "tests"
|
|
34
|
-
Requires-Dist: pytest-env; extra == "tests"
|
|
35
25
|
|
|
36
26
|
# aiohttp_msal Python library
|
|
37
27
|
|
|
@@ -139,3 +129,12 @@ def main()
|
|
|
139
129
|
# ...
|
|
140
130
|
# use the Graphclient
|
|
141
131
|
```
|
|
132
|
+
|
|
133
|
+
## Development
|
|
134
|
+
|
|
135
|
+
```bash
|
|
136
|
+
uv sync --all-extras
|
|
137
|
+
uv tool install ruff
|
|
138
|
+
uv tool install codespell
|
|
139
|
+
uv tool install pyproject-fmt
|
|
140
|
+
```
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
build-backend = "uv_build"
|
|
3
|
+
requires = [ "uv-build" ] # >=0.5.15,<0.6
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "aiohttp-msal"
|
|
7
|
+
version = "1.0.0"
|
|
8
|
+
description = "Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
keywords = [ "aiohttp", "asyncio", "msal", "oauth" ]
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [ { name = "Johann Kellerman", email = "kellerza@gmail.com" } ]
|
|
13
|
+
requires-python = ">=3.11"
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 4 - Beta",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Natural Language :: English",
|
|
18
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
19
|
+
"Programming Language :: Python :: 3.11",
|
|
20
|
+
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Programming Language :: Python :: 3.13",
|
|
22
|
+
]
|
|
23
|
+
dependencies = [
|
|
24
|
+
"aiohttp>=3.11.18,<3.13",
|
|
25
|
+
"aiohttp-session>=2.12.1,<3",
|
|
26
|
+
"attrs>=25.3,<26",
|
|
27
|
+
"msal>=1.32.3,<2",
|
|
28
|
+
]
|
|
29
|
+
optional-dependencies.aioredis = [ "aiohttp-session[aioredis]>=2.12.1,<3" ]
|
|
30
|
+
urls.Homepage = "https://github.com/kellerza/aiohttp_msal"
|
|
31
|
+
|
|
32
|
+
[dependency-groups]
|
|
33
|
+
dev = [
|
|
34
|
+
"mypy",
|
|
35
|
+
"pytest",
|
|
36
|
+
"pytest-aiohttp",
|
|
37
|
+
"pytest-asyncio",
|
|
38
|
+
"pytest-cov",
|
|
39
|
+
"pytest-env",
|
|
40
|
+
"types-redis",
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
[tool.ruff]
# The package lives under src/ since 1.0.0 (src layout); point ruff at it.
include = [ "src/aiohttp_msal/**/*.py", "tests/*.py" ]

format.line-ending = "lf"
format.docstring-code-format = true
lint.select = [
  "A",     # flake8-builtins
  "ASYNC", # flake8-async
  "B",     # bugbear
  "D",     # pydocstyle
  "E",     # pycodestyle
  "F",     # pyflakes
  "I",     # isort
  "PGH",   # pygrep-hooks
  "PIE",   # flake8-pie
  "PL",    # pylint
  "PTH",   # flake8-pathlib
  "PYI",   # flake8-pyi
  "RUF",   # ruff
  "UP",    # pyupgrade
  "W",     # pycodestyle
]
lint.ignore = [
  "D203",
  "D213",
  "E203",
  "E501",
  "PLC0415",
  "PLR2004",
  "PLW2901",
  "UP047",
]
lint.isort.no-lines-before = [ "future", "standard-library" ]
|
|
76
|
+
|
|
77
|
+
[tool.codespell]
|
|
78
|
+
skip = [ "build/*", "*.json", "*.csv", "**/node_modules/*", "./s2-js/dist/*" ]
|
|
79
|
+
#ignore-words-list = []
|
|
80
|
+
|
|
81
|
+
[tool.pytest.ini_options]
|
|
82
|
+
pythonpath = [ ".", "src" ]
|
|
83
|
+
filterwarnings = "ignore:.+@coroutine.+deprecated.+"
|
|
84
|
+
norecursedirs = [ ".git", "modules" ]
|
|
85
|
+
log_cli = true
|
|
86
|
+
log_cli_level = "DEBUG"
|
|
87
|
+
asyncio_mode = "auto"
|
|
88
|
+
addopts = "--cov=aiohttp_msal --cov-report xml:cov.xml"
|
|
89
|
+
asyncio_default_fixture_loop_scope = "function"
|
|
90
|
+
|
|
91
|
+
env = [
|
|
92
|
+
"X_SP_APP_PW=p1",
|
|
93
|
+
"X_SP_APP_ID=i1",
|
|
94
|
+
"X_SP_AUTHORITY=a1",
|
|
95
|
+
|
|
96
|
+
"Y_SP_APP_PW=p2",
|
|
97
|
+
"Y_SP_APP_ID=i2",
|
|
98
|
+
"Y_SP_AUTHORITY=a2",
|
|
99
|
+
|
|
100
|
+
"A_NUM=5",
|
|
101
|
+
"A_BOOL=True",
|
|
102
|
+
|
|
103
|
+
"B_NUM=10",
|
|
104
|
+
"B_BOOL=False",
|
|
105
|
+
"B_ROOT=/c/",
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
[tool.mypy]
|
|
109
|
+
disallow_untyped_defs = true
|
|
110
|
+
ignore_missing_imports = true
|
|
111
|
+
|
|
112
|
+
[tool.semantic_release]
|
|
113
|
+
commit = true
|
|
114
|
+
tag = true
|
|
115
|
+
vcs_release = true
|
|
116
|
+
commit_parser = "emoji"
|
|
117
|
+
version_toml = [ "pyproject.toml:project.version" ]
|
|
118
|
+
build_command = "pip install uv && uv build"
|
|
119
|
+
commit_version_number = true
|
|
120
|
+
|
|
121
|
+
# https://python-semantic-release.readthedocs.io/en/latest/multibranch_releases.html#configuring-multibranch-releases
|
|
122
|
+
[tool.semantic_release.branches.main]
|
|
123
|
+
match = "main"
|
|
124
|
+
|
|
125
|
+
[tool.semantic_release.commit_parser_options]
|
|
126
|
+
major_tags = [ ":boom:" ]
|
|
127
|
+
minor_tags = [ ":rocket:" ]
|
|
128
|
+
patch_tags = [ ":ambulance:", ":lock:", ":bug:", ":dolphin:" ]
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
"""aiohttp_msal."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
import
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
5
|
from functools import wraps
|
|
6
6
|
from inspect import getfullargspec, iscoroutinefunction
|
|
7
|
+
from typing import TypeVar, TypeVarTuple, cast
|
|
7
8
|
|
|
8
9
|
from aiohttp import ClientSession, web
|
|
9
10
|
from aiohttp_session import get_session
|
|
@@ -14,49 +15,63 @@ from aiohttp_msal.settings import ENV
|
|
|
14
15
|
|
|
15
16
|
_LOGGER = logging.getLogger(__name__)
|
|
16
17
|
|
|
17
|
-
|
|
18
|
+
_T = TypeVar("_T")
|
|
19
|
+
Ts = TypeVarTuple("Ts")
|
|
18
20
|
|
|
19
21
|
|
|
20
22
|
def msal_session(
|
|
21
|
-
*callbacks:
|
|
23
|
+
*callbacks: Callable[[AsyncMSAL], bool | Awaitable[bool]],
|
|
22
24
|
at_least_one: bool | None = False,
|
|
23
|
-
) ->
|
|
25
|
+
) -> Callable[
|
|
26
|
+
[Callable[[*Ts, AsyncMSAL], Awaitable[_T]]], Callable[[*Ts], Awaitable[_T]]
|
|
27
|
+
]:
|
|
24
28
|
"""Session decorator.
|
|
25
29
|
|
|
26
30
|
Arguments can include a list of function to perform login tests etc.
|
|
27
31
|
"""
|
|
28
32
|
|
|
29
|
-
def
|
|
33
|
+
def check_session(
|
|
34
|
+
func: Callable[[*Ts, AsyncMSAL], Awaitable[_T]],
|
|
35
|
+
) -> Callable[[*Ts], Awaitable[_T]]:
|
|
30
36
|
@wraps(func)
|
|
31
|
-
async def
|
|
37
|
+
async def wrapper(*args: *Ts) -> _T:
|
|
38
|
+
if len(args) < 1:
|
|
39
|
+
raise AssertionError("Requires a Request as the first parameter")
|
|
40
|
+
request = cast(web.Request, args[0])
|
|
32
41
|
ses = AsyncMSAL(session=await get_session(request))
|
|
33
42
|
for c_b in callbacks:
|
|
34
43
|
_ok = await c_b(ses) if iscoroutinefunction(c_b) else c_b(ses)
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
44
|
+
|
|
45
|
+
if at_least_one:
|
|
46
|
+
if _ok:
|
|
47
|
+
return await func(*args, ses)
|
|
48
|
+
elif not _ok:
|
|
38
49
|
raise web.HTTPForbidden
|
|
39
|
-
|
|
50
|
+
|
|
51
|
+
if at_least_one:
|
|
52
|
+
raise web.HTTPForbidden
|
|
53
|
+
return await func(*args, ses)
|
|
40
54
|
|
|
41
55
|
assert iscoroutinefunction(func), f"Function needs to be a coroutine: {func}"
|
|
42
56
|
spec = getfullargspec(func)
|
|
43
57
|
assert "ses" in spec.args, f"Function needs to accept a session 'ses': {func}"
|
|
44
|
-
return
|
|
58
|
+
return wrapper
|
|
45
59
|
|
|
46
|
-
return
|
|
60
|
+
return check_session
|
|
47
61
|
|
|
48
62
|
|
|
49
|
-
def
|
|
63
|
+
def auth_ok(ses: AsyncMSAL) -> bool:
    """Check whether the session belongs to an authenticated user.

    Returns True when ``ses.mail`` is truthy (a non-empty address), else False.
    The signature matches the callback type accepted by ``msal_session``.
    """
    mail = ses.mail
    if mail:
        return True
    return False
|
|
52
66
|
|
|
53
67
|
|
|
54
68
|
def auth_or(
|
|
55
|
-
*args:
|
|
56
|
-
) ->
|
|
69
|
+
*args: Callable[[AsyncMSAL], bool | Awaitable[bool]],
|
|
70
|
+
) -> Callable[[AsyncMSAL], Awaitable[bool]]:
|
|
57
71
|
"""Ensure either of the methods is valid. An alternative to at_least_one=True.
|
|
58
72
|
|
|
59
|
-
Arguments can include a list of function to perform login tests etc.
|
|
73
|
+
Arguments can include a list of function to perform login tests etc.
|
|
74
|
+
"""
|
|
60
75
|
|
|
61
76
|
async def or_auth(ses: AsyncMSAL) -> bool:
|
|
62
77
|
"""Or."""
|
|
@@ -74,11 +89,10 @@ def auth_or(
|
|
|
74
89
|
async def app_init_redis_session(
|
|
75
90
|
app: web.Application, max_age: int = 3600 * 24 * 90
|
|
76
91
|
) -> None:
|
|
77
|
-
"""
|
|
92
|
+
"""Init an aiohttp_session with Redis storage helper.
|
|
78
93
|
|
|
79
94
|
You can initialize your own aiohttp_session & storage provider.
|
|
80
95
|
"""
|
|
81
|
-
# pylint: disable=import-outside-toplevel
|
|
82
96
|
from aiohttp_session import redis_storage
|
|
83
97
|
from redis.asyncio import from_url
|
|
84
98
|
|
|
@@ -86,7 +100,7 @@ async def app_init_redis_session(
|
|
|
86
100
|
|
|
87
101
|
_LOGGER.info("Connect to Redis %s", ENV.REDIS)
|
|
88
102
|
try:
|
|
89
|
-
ENV.database = from_url(ENV.REDIS)
|
|
103
|
+
ENV.database = from_url(ENV.REDIS)
|
|
90
104
|
# , encoding="utf-8", decode_responses=True
|
|
91
105
|
except ConnectionRefusedError as err:
|
|
92
106
|
raise ConnectionError("Could not connect to REDIS server") from err
|
|
@@ -112,7 +126,7 @@ async def check_proxy() -> None:
|
|
|
112
126
|
if resp.ok:
|
|
113
127
|
return
|
|
114
128
|
raise ConnectionError(await resp.text())
|
|
115
|
-
except Exception as err:
|
|
129
|
+
except Exception as err:
|
|
116
130
|
raise ConnectionError(
|
|
117
131
|
"No connection to the Internet. Required for OAuth. Check your Proxy?"
|
|
118
132
|
) from err
|
|
@@ -7,16 +7,24 @@ Once you have the OAuth tokens store in the session, you are free to make reques
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
|
-
from
|
|
11
|
-
from
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from functools import partial, partialmethod, wraps
|
|
12
|
+
from typing import Any, ClassVar, Literal, Unpack
|
|
12
13
|
|
|
13
14
|
from aiohttp import web
|
|
14
|
-
from aiohttp.client import
|
|
15
|
+
from aiohttp.client import (
|
|
16
|
+
ClientResponse,
|
|
17
|
+
ClientSession,
|
|
18
|
+
_RequestContextManager,
|
|
19
|
+
_RequestOptions,
|
|
20
|
+
)
|
|
21
|
+
from aiohttp.typedefs import StrOrURL
|
|
15
22
|
from aiohttp_session import Session
|
|
16
23
|
from msal import ConfidentialClientApplication, SerializableTokenCache
|
|
17
24
|
|
|
18
25
|
from aiohttp_msal.settings import ENV
|
|
19
26
|
|
|
27
|
+
HttpMethods = Literal["get", "post", "put", "patch", "delete"]
|
|
20
28
|
HTTP_GET = "get"
|
|
21
29
|
HTTP_POST = "post"
|
|
22
30
|
HTTP_PUT = "put"
|
|
@@ -62,24 +70,23 @@ class AsyncMSAL:
|
|
|
62
70
|
Use until such time as MSAL Python gets a true async version.
|
|
63
71
|
"""
|
|
64
72
|
|
|
65
|
-
_token_cache: SerializableTokenCache
|
|
66
|
-
_app: ConfidentialClientApplication
|
|
67
|
-
|
|
73
|
+
_token_cache: SerializableTokenCache
|
|
74
|
+
_app: ConfidentialClientApplication
|
|
75
|
+
client_session: ClassVar[ClientSession | None] = None
|
|
68
76
|
|
|
69
77
|
def __init__(
    self,
    session: Session | dict[str, Any],
    save_callback: Callable[[Session | dict[str, Any]], None] | None = None,
):
    """Init the class.

    **session** must be an ``aiohttp_session.Session`` or a plain ``dict``;
    anything else raises ``ValueError``.

    **save_callback** will be called if the token cache changes. Optional.
    Not required when the session parameter is an aiohttp_session.Session.
    """
    self.session = session
    self.save_callback = save_callback
    is_supported = isinstance(session, Session | dict)
    if is_supported:
        return
    raise ValueError(f"session or dict-like object required {session}")
|
|
84
91
|
|
|
85
92
|
@property
|
|
@@ -110,12 +117,12 @@ class AsyncMSAL:
|
|
|
110
117
|
)
|
|
111
118
|
return self._app
|
|
112
119
|
|
|
113
|
-
def
|
|
120
|
+
def save_token_cache(self) -> None:
    """Save the token cache if it changed.

    When the MSAL token cache reports a state change, its serialized form is
    written into ``self.session[TOKEN_CACHE]`` and, if a ``save_callback`` was
    supplied, that callback is invoked with the session.
    """
    if not self.token_cache.has_state_changed:
        return
    self.session[TOKEN_CACHE] = self.token_cache.serialize()
    callback = self.save_callback
    if callback:
        callback(self.session)
|
|
119
126
|
|
|
120
127
|
def build_auth_code_flow(
|
|
121
128
|
self,
|
|
@@ -125,8 +132,8 @@ class AsyncMSAL:
|
|
|
125
132
|
**kwargs: Any,
|
|
126
133
|
) -> str:
|
|
127
134
|
"""First step - Start the flow."""
|
|
128
|
-
self.session[TOKEN_CACHE] = None
|
|
129
|
-
self.session[USER_EMAIL] = None
|
|
135
|
+
self.session[TOKEN_CACHE] = None
|
|
136
|
+
self.session[USER_EMAIL] = None
|
|
130
137
|
self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
|
|
131
138
|
scopes or DEFAULT_SCOPES,
|
|
132
139
|
redirect_uri=redirect_uri,
|
|
@@ -149,10 +156,9 @@ class AsyncMSAL:
|
|
|
149
156
|
raise web.HTTPBadRequest(text=str(result["error"]))
|
|
150
157
|
if "id_token_claims" not in result:
|
|
151
158
|
raise web.HTTPBadRequest(text=f"Expected id_token_claims in {result}")
|
|
152
|
-
self.
|
|
153
|
-
|
|
154
|
-
"preferred_username"
|
|
155
|
-
)
|
|
159
|
+
self.save_token_cache()
|
|
160
|
+
if tok := result.get("id_token_claims"):
|
|
161
|
+
self.session[USER_EMAIL] = tok.get("preferred_username")
|
|
156
162
|
|
|
157
163
|
async def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
|
|
158
164
|
"""Second step - Acquire token, async version."""
|
|
@@ -167,7 +173,7 @@ class AsyncMSAL:
|
|
|
167
173
|
result = self.app.acquire_token_silent(
|
|
168
174
|
scopes=scopes or DEFAULT_SCOPES, account=accounts[0]
|
|
169
175
|
)
|
|
170
|
-
self.
|
|
176
|
+
self.save_token_cache()
|
|
171
177
|
return result
|
|
172
178
|
return None
|
|
173
179
|
|
|
@@ -175,7 +181,9 @@ class AsyncMSAL:
|
|
|
175
181
|
"""Acquire a token based on username."""
|
|
176
182
|
return await asyncio.get_event_loop().run_in_executor(None, self.get_token)
|
|
177
183
|
|
|
178
|
-
async def request(
|
|
184
|
+
async def request(
|
|
185
|
+
self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
|
|
186
|
+
) -> ClientResponse:
|
|
179
187
|
"""Make a request to url using an oauth session.
|
|
180
188
|
|
|
181
189
|
:param str url: url to send request to
|
|
@@ -184,16 +192,14 @@ class AsyncMSAL:
|
|
|
184
192
|
:return: Response of the request
|
|
185
193
|
:rtype: aiohttp.Response
|
|
186
194
|
"""
|
|
187
|
-
if not self._clientsession:
|
|
188
|
-
AsyncMSAL._clientsession = ClientSession(trust_env=True)
|
|
189
|
-
|
|
190
195
|
token = await self.async_get_token()
|
|
191
196
|
if token is None:
|
|
192
197
|
raise web.HTTPClientError(text="No login token available.")
|
|
193
198
|
|
|
194
199
|
kwargs = kwargs.copy()
|
|
195
200
|
# Ensure headers exist & make a copy
|
|
196
|
-
|
|
201
|
+
headers = dict[str, str](kwargs.get("headers") or {}) # type:ignore[arg-type]
|
|
202
|
+
kwargs["headers"] = headers
|
|
197
203
|
|
|
198
204
|
headers["Authorization"] = "Bearer " + token["access_token"]
|
|
199
205
|
|
|
@@ -207,17 +213,27 @@ class AsyncMSAL:
|
|
|
207
213
|
if "data" in kwargs:
|
|
208
214
|
kwargs["data"] = json.dumps(kwargs["data"]) # auto convert to json
|
|
209
215
|
|
|
210
|
-
|
|
216
|
+
if not AsyncMSAL.client_session:
|
|
217
|
+
AsyncMSAL.client_session = ClientSession(trust_env=True)
|
|
218
|
+
|
|
219
|
+
return await AsyncMSAL.client_session.request(method, url, **kwargs)
|
|
220
|
+
|
|
221
|
+
def request_ctx(
    self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
) -> _RequestContextManager:
    """Request context manager.

    Wraps the authenticated ``self.request(method, url, **kwargs)`` coroutine
    in aiohttp's ``_RequestContextManager``, so callers can use
    ``async with ses.request_ctx("get", url) as resp:``.
    """
    return _RequestContextManager(self.request(method, url, **kwargs))

# Convenience aliases: ``ses.get(url, ...)`` / ``ses.post(url, ...)`` are
# ``request_ctx`` with the HTTP method pre-bound.
get = partialmethod(request_ctx, HTTP_GET)
post = partialmethod(request_ctx, HTTP_POST)
|
|
221
237
|
|
|
222
238
|
@property
|
|
223
239
|
def mail(self) -> str:
|
|
@@ -4,8 +4,9 @@ import asyncio
|
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
6
|
import time
|
|
7
|
+
from collections.abc import AsyncGenerator
|
|
7
8
|
from contextlib import AsyncExitStack, asynccontextmanager
|
|
8
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
9
10
|
|
|
10
11
|
from redis.asyncio import Redis, from_url
|
|
11
12
|
|
|
@@ -30,15 +31,15 @@ async def get_redis() -> AsyncGenerator[Redis, None]:
|
|
|
30
31
|
try:
|
|
31
32
|
yield redis
|
|
32
33
|
finally:
|
|
33
|
-
MENV.database = None # type:ignore
|
|
34
|
+
MENV.database = None # type:ignore[assignment]
|
|
34
35
|
await redis.close()
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
async def session_iter(
|
|
38
39
|
redis: Redis,
|
|
39
40
|
*,
|
|
40
|
-
match:
|
|
41
|
-
key_match:
|
|
41
|
+
match: dict[str, str] | None = None,
|
|
42
|
+
key_match: str | None = None,
|
|
42
43
|
) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
|
|
43
44
|
"""Iterate over the Redis keys to find a specific session.
|
|
44
45
|
|
|
@@ -55,10 +56,10 @@ async def session_iter(
|
|
|
55
56
|
sval = await redis.get(key)
|
|
56
57
|
created, ses = 0, {}
|
|
57
58
|
try:
|
|
58
|
-
val = json.loads(sval) # type: ignore
|
|
59
|
+
val = json.loads(sval) # type: ignore[arg-type]
|
|
59
60
|
created = int(val["created"])
|
|
60
61
|
ses = val["session"]
|
|
61
|
-
except Exception:
|
|
62
|
+
except Exception:
|
|
62
63
|
pass
|
|
63
64
|
if match:
|
|
64
65
|
# Ensure we match all the supplied terms
|
|
@@ -73,7 +74,7 @@ async def session_iter(
|
|
|
73
74
|
|
|
74
75
|
|
|
75
76
|
async def session_clean(
|
|
76
|
-
redis: Redis, *, max_age: int = 90, expected_keys:
|
|
77
|
+
redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
|
|
77
78
|
) -> None:
|
|
78
79
|
"""Clear session entries older than max_age days."""
|
|
79
80
|
rem, keep = 0, 0
|
|
@@ -93,11 +94,29 @@ async def session_clean(
|
|
|
93
94
|
_LOGGER.debug("No sessions removed (%s total)", keep)
|
|
94
95
|
|
|
95
96
|
|
|
96
|
-
def
|
|
97
|
+
async def invalid_sessions(redis: Redis) -> None:
    """Find & clean invalid sessions.

    Scans keys matching ``f"{MENV.COOKIE_NAME}*"`` and deletes any entry whose
    value is not JSON of the form ``{"created": <int>, "session": <dict>}``.
    Keys that return ``None`` on read are skipped. Every removal is logged at
    WARNING level together with the reason.
    """
    async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
        if not isinstance(key, str):
            key = key.decode()
        sval = await redis.get(key)
        if sval is None:
            continue
        # Deliberately broad: anything that cannot be parsed and validated is
        # treated as an invalid session and removed (best-effort cleanup).
        try:
            val = json.loads(sval)
            # Explicit checks instead of ``assert``: asserts are stripped under
            # ``python -O``, which would stop malformed sessions being removed.
            if not isinstance(val, dict):
                raise TypeError(f"expected a JSON object, got {type(val).__name__}")
            if not isinstance(val["created"], int):
                raise TypeError("'created' is not an int")
            if not isinstance(val["session"], dict):
                raise TypeError("'session' is not a dict")
        except Exception as err:
            _LOGGER.warning("Removing session %s: %s", key, err)
            await redis.delete(key)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
|
|
97
115
|
"""Create a AsyncMSAL session.
|
|
98
116
|
|
|
99
117
|
When get_token refreshes the token retrieved from Redis, the save_cache callback
|
|
100
|
-
will be responsible to update the cache in Redis.
|
|
118
|
+
will be responsible to update the cache in Redis.
|
|
119
|
+
"""
|
|
101
120
|
|
|
102
121
|
async def async_save_cache(_: dict) -> None:
|
|
103
122
|
"""Save the token cache to Redis."""
|
|
@@ -111,11 +130,11 @@ def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
|
|
|
111
130
|
except RuntimeError:
|
|
112
131
|
asyncio.run(async_save_cache(*args))
|
|
113
132
|
|
|
114
|
-
return AsyncMSAL(session,
|
|
133
|
+
return AsyncMSAL(session, save_callback=save_cache)
|
|
115
134
|
|
|
116
135
|
|
|
117
136
|
async def get_session(
|
|
118
|
-
email: str, *, redis:
|
|
137
|
+
email: str, *, redis: Redis | None = None, scope: str = ""
|
|
119
138
|
) -> AsyncMSAL:
|
|
120
139
|
"""Get a session from Redis."""
|
|
121
140
|
cnt = 0
|
|
@@ -126,7 +145,7 @@ async def get_session(
|
|
|
126
145
|
cnt += 1
|
|
127
146
|
if scope and scope not in str(session.get("token_cache")).lower():
|
|
128
147
|
continue
|
|
129
|
-
return _session_factory(key,
|
|
148
|
+
return _session_factory(key, created, session)
|
|
130
149
|
msg = f"Session for {email}"
|
|
131
150
|
if not scope:
|
|
132
151
|
raise ValueError(f"{msg} not found")
|
|
@@ -136,7 +155,7 @@ async def get_session(
|
|
|
136
155
|
async def redis_get_json(key: str) -> list | dict | None:
|
|
137
156
|
"""Get a key from redis."""
|
|
138
157
|
res = await MENV.database.get(key)
|
|
139
|
-
if isinstance(res,
|
|
158
|
+
if isinstance(res, str | bytes | bytearray):
|
|
140
159
|
return json.loads(res)
|
|
141
160
|
if res is not None:
|
|
142
161
|
_LOGGER.warning("Unexpected type for %s: %s", key, type(res))
|
|
@@ -148,7 +167,7 @@ async def redis_get(key: str) -> str | None:
|
|
|
148
167
|
res = await MENV.database.get(key)
|
|
149
168
|
if isinstance(res, str):
|
|
150
169
|
return res
|
|
151
|
-
if isinstance(res,
|
|
170
|
+
if isinstance(res, bytes | bytearray):
|
|
152
171
|
return res.decode()
|
|
153
172
|
if res is not None:
|
|
154
173
|
_LOGGER.warning("Unexpected type for %s: %s", key, type(res))
|