aiohttp-msal 0.7.1__tar.gz → 1.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1}/PKG-INFO +27 -28
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1}/README.md +12 -3
- aiohttp_msal-1.0.1/pyproject.toml +128 -0
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/__init__.py +27 -20
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/msal_async.py +55 -55
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/redis_tools.py +33 -14
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/routes.py +6 -7
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/settings.py +16 -12
- aiohttp_msal-1.0.1/src/aiohttp_msal/settings_base.py +93 -0
- {aiohttp_msal-0.7.1 → aiohttp_msal-1.0.1/src}/aiohttp_msal/user_info.py +8 -4
- aiohttp_msal-0.7.1/LICENSE +0 -21
- aiohttp_msal-0.7.1/aiohttp_msal/settings_base.py +0 -83
- aiohttp_msal-0.7.1/aiohttp_msal.egg-info/PKG-INFO +0 -141
- aiohttp_msal-0.7.1/aiohttp_msal.egg-info/SOURCES.txt +0 -23
- aiohttp_msal-0.7.1/aiohttp_msal.egg-info/dependency_links.txt +0 -1
- aiohttp_msal-0.7.1/aiohttp_msal.egg-info/requires.txt +0 -16
- aiohttp_msal-0.7.1/aiohttp_msal.egg-info/top_level.txt +0 -2
- aiohttp_msal-0.7.1/aiohttp_msal.egg-info/zip-safe +0 -1
- aiohttp_msal-0.7.1/pyproject.toml +0 -65
- aiohttp_msal-0.7.1/setup.cfg +0 -55
- aiohttp_msal-0.7.1/setup.py +0 -6
- aiohttp_msal-0.7.1/tests/__init__.py +0 -0
- aiohttp_msal-0.7.1/tests/test_init.py +0 -79
- aiohttp_msal-0.7.1/tests/test_msal_async.py +0 -12
- aiohttp_msal-0.7.1/tests/test_redis_tools.py +0 -60
- aiohttp_msal-0.7.1/tests/test_settings.py +0 -11
|
@@ -1,39 +1,29 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
2
|
-
Name:
|
|
3
|
-
Version: 0.
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: aiohttp-msal
|
|
3
|
+
Version: 1.0.1
|
|
4
4
|
Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
|
|
5
|
-
|
|
5
|
+
Keywords: aiohttp,asyncio,msal,oauth
|
|
6
6
|
Author: Johann Kellerman
|
|
7
|
-
Author-email: kellerza@gmail.com
|
|
7
|
+
Author-email: Johann Kellerman <kellerza@gmail.com>
|
|
8
8
|
License: MIT
|
|
9
|
-
Keywords: msal,oauth,aiohttp,asyncio
|
|
10
9
|
Classifier: Development Status :: 4 - Beta
|
|
11
10
|
Classifier: Intended Audience :: Developers
|
|
12
11
|
Classifier: Natural Language :: English
|
|
13
|
-
Classifier: Programming Language :: Python :: 3
|
|
14
12
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
15
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
16
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
18
|
-
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Requires-Dist: aiohttp>=3.11.18,<3.13
|
|
17
|
+
Requires-Dist: aiohttp-session>=2.12.1,<3
|
|
18
|
+
Requires-Dist: attrs>=25.3,<26
|
|
19
|
+
Requires-Dist: msal>=1.32.3,<2
|
|
20
|
+
Requires-Dist: aiohttp-session[aioredis]>=2.12.1,<3 ; extra == 'aioredis'
|
|
21
|
+
Requires-Python: >=3.11
|
|
22
|
+
Project-URL: Homepage, https://github.com/kellerza/aiohttp_msal
|
|
23
|
+
Provides-Extra: aioredis
|
|
19
24
|
Description-Content-Type: text/markdown
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
Requires-Dist: aiohttp_session>=2.12
|
|
23
|
-
Requires-Dist: aiohttp>=3.8
|
|
24
|
-
Provides-Extra: redis
|
|
25
|
-
Requires-Dist: aiohttp_session[aioredis]>=2.12; extra == "redis"
|
|
26
|
-
Provides-Extra: tests
|
|
27
|
-
Requires-Dist: black==24.8.0; extra == "tests"
|
|
28
|
-
Requires-Dist: pylint==3.2.6; extra == "tests"
|
|
29
|
-
Requires-Dist: flake8; extra == "tests"
|
|
30
|
-
Requires-Dist: pytest-aiohttp; extra == "tests"
|
|
31
|
-
Requires-Dist: pytest; extra == "tests"
|
|
32
|
-
Requires-Dist: pytest-cov; extra == "tests"
|
|
33
|
-
Requires-Dist: pytest-asyncio; extra == "tests"
|
|
34
|
-
Requires-Dist: pytest-env; extra == "tests"
|
|
35
|
-
|
|
36
|
-
# aiohttp_msal Python library
|
|
25
|
+
|
|
26
|
+
# Async based MSAL helper for aiohttp - aiohttp_msal Python library
|
|
37
27
|
|
|
38
28
|
Authorization Code Flow Helper. Learn more about auth-code-flow at
|
|
39
29
|
<https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow>
|
|
@@ -134,8 +124,17 @@ from aiohttp_msal.redis_tools import get_session
|
|
|
134
124
|
def main()
|
|
135
125
|
# Uses the redis.asyncio driver to retrieve the current token
|
|
136
126
|
# Will update the token_cache if a RefreshToken was used
|
|
137
|
-
|
|
138
|
-
client = GraphClient(
|
|
127
|
+
ses = asyncio.run(get_session(MYEMAIL))
|
|
128
|
+
client = GraphClient(ses.get_token)
|
|
139
129
|
# ...
|
|
140
130
|
# use the Graphclient
|
|
141
131
|
```
|
|
132
|
+
|
|
133
|
+
## Development
|
|
134
|
+
|
|
135
|
+
```bash
|
|
136
|
+
uv sync --all-extras
|
|
137
|
+
uv tool install ruff
|
|
138
|
+
uv tool install codespell
|
|
139
|
+
uv tool install pyproject-fmt
|
|
140
|
+
```
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# aiohttp_msal Python library
|
|
1
|
+
# Async based MSAL helper for aiohttp - aiohttp_msal Python library
|
|
2
2
|
|
|
3
3
|
Authorization Code Flow Helper. Learn more about auth-code-flow at
|
|
4
4
|
<https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow>
|
|
@@ -99,8 +99,17 @@ from aiohttp_msal.redis_tools import get_session
|
|
|
99
99
|
def main()
|
|
100
100
|
# Uses the redis.asyncio driver to retrieve the current token
|
|
101
101
|
# Will update the token_cache if a RefreshToken was used
|
|
102
|
-
|
|
103
|
-
client = GraphClient(
|
|
102
|
+
ses = asyncio.run(get_session(MYEMAIL))
|
|
103
|
+
client = GraphClient(ses.get_token)
|
|
104
104
|
# ...
|
|
105
105
|
# use the Graphclient
|
|
106
106
|
```
|
|
107
|
+
|
|
108
|
+
## Development
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
uv sync --all-extras
|
|
112
|
+
uv tool install ruff
|
|
113
|
+
uv tool install codespell
|
|
114
|
+
uv tool install pyproject-fmt
|
|
115
|
+
```
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
build-backend = "uv_build"
|
|
3
|
+
requires = [ "uv-build" ] # >=0.5.15,<0.6
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "aiohttp-msal"
|
|
7
|
+
version = "1.0.1"
|
|
8
|
+
description = "Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
keywords = [ "aiohttp", "asyncio", "msal", "oauth" ]
|
|
11
|
+
license = { text = "MIT" }
|
|
12
|
+
authors = [ { name = "Johann Kellerman", email = "kellerza@gmail.com" } ]
|
|
13
|
+
requires-python = ">=3.11"
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 4 - Beta",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Natural Language :: English",
|
|
18
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
19
|
+
"Programming Language :: Python :: 3.11",
|
|
20
|
+
"Programming Language :: Python :: 3.12",
|
|
21
|
+
"Programming Language :: Python :: 3.13",
|
|
22
|
+
]
|
|
23
|
+
dependencies = [
|
|
24
|
+
"aiohttp>=3.11.18,<3.13",
|
|
25
|
+
"aiohttp-session>=2.12.1,<3",
|
|
26
|
+
"attrs>=25.3,<26",
|
|
27
|
+
"msal>=1.32.3,<2",
|
|
28
|
+
]
|
|
29
|
+
optional-dependencies.aioredis = [ "aiohttp-session[aioredis]>=2.12.1,<3" ]
|
|
30
|
+
urls.Homepage = "https://github.com/kellerza/aiohttp_msal"
|
|
31
|
+
|
|
32
|
+
[dependency-groups]
|
|
33
|
+
dev = [
|
|
34
|
+
"mypy",
|
|
35
|
+
"pytest",
|
|
36
|
+
"pytest-aiohttp",
|
|
37
|
+
"pytest-asyncio",
|
|
38
|
+
"pytest-cov",
|
|
39
|
+
"pytest-env",
|
|
40
|
+
"types-redis",
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
[tool.ruff]
|
|
44
|
+
include = [ "aiohttp_msal/**/*.py", "tests/*.py" ]
|
|
45
|
+
|
|
46
|
+
format.line-ending = "lf"
|
|
47
|
+
format.docstring-code-format = true
|
|
48
|
+
lint.select = [
|
|
49
|
+
"A", # flake8-builtins
|
|
50
|
+
"ASYNC", # flake8-async
|
|
51
|
+
"B", # bugbear
|
|
52
|
+
"D", # pydocstyle
|
|
53
|
+
"E", # pycodestyle
|
|
54
|
+
"F", # pyflakes
|
|
55
|
+
"I", # isort
|
|
56
|
+
"PGH", # pygrep-hooks
|
|
57
|
+
"PIE", # flake8-pie
|
|
58
|
+
"PL", # pylint
|
|
59
|
+
"PTH", # flake8-pathlib
|
|
60
|
+
"PYI", # flake8-pyi
|
|
61
|
+
"RUF", # ruff
|
|
62
|
+
"UP", # pyupgrade
|
|
63
|
+
"W", # pycodestyle
|
|
64
|
+
]
|
|
65
|
+
lint.ignore = [
|
|
66
|
+
"D203",
|
|
67
|
+
"D213",
|
|
68
|
+
"E203",
|
|
69
|
+
"E501",
|
|
70
|
+
"PLC0415",
|
|
71
|
+
"PLR2004",
|
|
72
|
+
"PLW2901",
|
|
73
|
+
"UP047",
|
|
74
|
+
]
|
|
75
|
+
lint.isort.no-lines-before = [ "future", "standard-library" ]
|
|
76
|
+
|
|
77
|
+
[tool.codespell]
|
|
78
|
+
skip = [ "build/*", "*.json", "*.csv", "**/node_modules/*", "./s2-js/dist/*" ]
|
|
79
|
+
#ignore-words-list = []
|
|
80
|
+
|
|
81
|
+
[tool.pytest.ini_options]
|
|
82
|
+
pythonpath = [ ".", "src" ]
|
|
83
|
+
filterwarnings = "ignore:.+@coroutine.+deprecated.+"
|
|
84
|
+
norecursedirs = [ ".git", "modules" ]
|
|
85
|
+
log_cli = true
|
|
86
|
+
log_cli_level = "DEBUG"
|
|
87
|
+
asyncio_mode = "auto"
|
|
88
|
+
addopts = "--cov=aiohttp_msal --cov-report xml:cov.xml"
|
|
89
|
+
asyncio_default_fixture_loop_scope = "function"
|
|
90
|
+
|
|
91
|
+
env = [
|
|
92
|
+
"X_SP_APP_PW=p1",
|
|
93
|
+
"X_SP_APP_ID=i1",
|
|
94
|
+
"X_SP_AUTHORITY=a1",
|
|
95
|
+
|
|
96
|
+
"Y_SP_APP_PW=p2",
|
|
97
|
+
"Y_SP_APP_ID=i2",
|
|
98
|
+
"Y_SP_AUTHORITY=a2",
|
|
99
|
+
|
|
100
|
+
"A_NUM=5",
|
|
101
|
+
"A_BOOL=True",
|
|
102
|
+
|
|
103
|
+
"B_NUM=10",
|
|
104
|
+
"B_BOOL=False",
|
|
105
|
+
"B_ROOT=/c/",
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
[tool.mypy]
|
|
109
|
+
disallow_untyped_defs = true
|
|
110
|
+
ignore_missing_imports = true
|
|
111
|
+
|
|
112
|
+
[tool.semantic_release]
|
|
113
|
+
commit = true
|
|
114
|
+
tag = true
|
|
115
|
+
vcs_release = true
|
|
116
|
+
commit_parser = "emoji"
|
|
117
|
+
version_toml = [ "pyproject.toml:project.version" ]
|
|
118
|
+
build_command = "pip install uv && uv build"
|
|
119
|
+
commit_version_number = true
|
|
120
|
+
|
|
121
|
+
# https://python-semantic-release.readthedocs.io/en/latest/multibranch_releases.html#configuring-multibranch-releases
|
|
122
|
+
[tool.semantic_release.branches.main]
|
|
123
|
+
match = "main"
|
|
124
|
+
|
|
125
|
+
[tool.semantic_release.commit_parser_options]
|
|
126
|
+
major_tags = [ ":boom:" ]
|
|
127
|
+
minor_tags = [ ":rocket:" ]
|
|
128
|
+
patch_tags = [ ":ambulance:", ":lock:", ":bug:", ":dolphin:" ]
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
"""aiohttp_msal."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
import
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
5
|
from functools import wraps
|
|
6
6
|
from inspect import getfullargspec, iscoroutinefunction
|
|
7
|
+
from typing import TypeVar, TypeVarTuple, cast
|
|
7
8
|
|
|
8
9
|
from aiohttp import ClientSession, web
|
|
9
10
|
from aiohttp_session import get_session
|
|
@@ -14,43 +15,49 @@ from aiohttp_msal.settings import ENV
|
|
|
14
15
|
|
|
15
16
|
_LOGGER = logging.getLogger(__name__)
|
|
16
17
|
|
|
17
|
-
|
|
18
|
+
_T = TypeVar("_T")
|
|
19
|
+
Ts = TypeVarTuple("Ts")
|
|
18
20
|
|
|
19
21
|
|
|
20
22
|
def msal_session(
|
|
21
|
-
*callbacks:
|
|
23
|
+
*callbacks: Callable[[AsyncMSAL], bool | Awaitable[bool]],
|
|
22
24
|
at_least_one: bool | None = False,
|
|
23
|
-
) ->
|
|
25
|
+
) -> Callable[
|
|
26
|
+
[Callable[[*Ts, AsyncMSAL], Awaitable[_T]]], Callable[[*Ts], Awaitable[_T]]
|
|
27
|
+
]:
|
|
24
28
|
"""Session decorator.
|
|
25
29
|
|
|
26
30
|
Arguments can include a list of function to perform login tests etc.
|
|
27
31
|
"""
|
|
28
32
|
|
|
29
|
-
def
|
|
33
|
+
def check_session(
|
|
34
|
+
func: Callable[[*Ts, AsyncMSAL], Awaitable[_T]],
|
|
35
|
+
) -> Callable[[*Ts], Awaitable[_T]]:
|
|
30
36
|
@wraps(func)
|
|
31
|
-
async def
|
|
37
|
+
async def wrapper(*args: *Ts) -> _T:
|
|
38
|
+
if len(args) < 1:
|
|
39
|
+
raise AssertionError("Requires a Request as the first parameter")
|
|
40
|
+
request = cast(web.Request, args[0])
|
|
32
41
|
ses = AsyncMSAL(session=await get_session(request))
|
|
33
42
|
for c_b in callbacks:
|
|
34
43
|
_ok = await c_b(ses) if iscoroutinefunction(c_b) else c_b(ses)
|
|
35
44
|
|
|
36
45
|
if at_least_one:
|
|
37
46
|
if _ok:
|
|
38
|
-
return await func(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
if not _ok:
|
|
47
|
+
return await func(*args, ses)
|
|
48
|
+
elif not _ok:
|
|
42
49
|
raise web.HTTPForbidden
|
|
43
50
|
|
|
44
51
|
if at_least_one:
|
|
45
52
|
raise web.HTTPForbidden
|
|
46
|
-
return await func(
|
|
53
|
+
return await func(*args, ses)
|
|
47
54
|
|
|
48
55
|
assert iscoroutinefunction(func), f"Function needs to be a coroutine: {func}"
|
|
49
56
|
spec = getfullargspec(func)
|
|
50
57
|
assert "ses" in spec.args, f"Function needs to accept a session 'ses': {func}"
|
|
51
|
-
return
|
|
58
|
+
return wrapper
|
|
52
59
|
|
|
53
|
-
return
|
|
60
|
+
return check_session
|
|
54
61
|
|
|
55
62
|
|
|
56
63
|
def auth_ok(ses: AsyncMSAL) -> bool:
|
|
@@ -59,11 +66,12 @@ def auth_ok(ses: AsyncMSAL) -> bool:
|
|
|
59
66
|
|
|
60
67
|
|
|
61
68
|
def auth_or(
|
|
62
|
-
*args:
|
|
63
|
-
) ->
|
|
69
|
+
*args: Callable[[AsyncMSAL], bool | Awaitable[bool]],
|
|
70
|
+
) -> Callable[[AsyncMSAL], Awaitable[bool]]:
|
|
64
71
|
"""Ensure either of the methods is valid. An alternative to at_least_one=True.
|
|
65
72
|
|
|
66
|
-
Arguments can include a list of function to perform login tests etc.
|
|
73
|
+
Arguments can include a list of function to perform login tests etc.
|
|
74
|
+
"""
|
|
67
75
|
|
|
68
76
|
async def or_auth(ses: AsyncMSAL) -> bool:
|
|
69
77
|
"""Or."""
|
|
@@ -81,11 +89,10 @@ def auth_or(
|
|
|
81
89
|
async def app_init_redis_session(
|
|
82
90
|
app: web.Application, max_age: int = 3600 * 24 * 90
|
|
83
91
|
) -> None:
|
|
84
|
-
"""
|
|
92
|
+
"""Init an aiohttp_session with Redis storage helper.
|
|
85
93
|
|
|
86
94
|
You can initialize your own aiohttp_session & storage provider.
|
|
87
95
|
"""
|
|
88
|
-
# pylint: disable=import-outside-toplevel
|
|
89
96
|
from aiohttp_session import redis_storage
|
|
90
97
|
from redis.asyncio import from_url
|
|
91
98
|
|
|
@@ -93,7 +100,7 @@ async def app_init_redis_session(
|
|
|
93
100
|
|
|
94
101
|
_LOGGER.info("Connect to Redis %s", ENV.REDIS)
|
|
95
102
|
try:
|
|
96
|
-
ENV.database = from_url(ENV.REDIS)
|
|
103
|
+
ENV.database = from_url(ENV.REDIS)
|
|
97
104
|
# , encoding="utf-8", decode_responses=True
|
|
98
105
|
except ConnectionRefusedError as err:
|
|
99
106
|
raise ConnectionError("Could not connect to REDIS server") from err
|
|
@@ -119,7 +126,7 @@ async def check_proxy() -> None:
|
|
|
119
126
|
if resp.ok:
|
|
120
127
|
return
|
|
121
128
|
raise ConnectionError(await resp.text())
|
|
122
|
-
except Exception as err:
|
|
129
|
+
except Exception as err:
|
|
123
130
|
raise ConnectionError(
|
|
124
131
|
"No connection to the Internet. Required for OAuth. Check your Proxy?"
|
|
125
132
|
) from err
|
|
@@ -7,16 +7,24 @@ Once you have the OAuth tokens store in the session, you are free to make reques
|
|
|
7
7
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
|
-
from
|
|
11
|
-
from
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from functools import cached_property, partial, partialmethod, wraps
|
|
12
|
+
from typing import Any, ClassVar, Literal, Unpack
|
|
12
13
|
|
|
13
14
|
from aiohttp import web
|
|
14
|
-
from aiohttp.client import
|
|
15
|
+
from aiohttp.client import (
|
|
16
|
+
ClientResponse,
|
|
17
|
+
ClientSession,
|
|
18
|
+
_RequestContextManager,
|
|
19
|
+
_RequestOptions,
|
|
20
|
+
)
|
|
21
|
+
from aiohttp.typedefs import StrOrURL
|
|
15
22
|
from aiohttp_session import Session
|
|
16
23
|
from msal import ConfidentialClientApplication, SerializableTokenCache
|
|
17
24
|
|
|
18
25
|
from aiohttp_msal.settings import ENV
|
|
19
26
|
|
|
27
|
+
HttpMethods = Literal["get", "post", "put", "patch", "delete"]
|
|
20
28
|
HTTP_GET = "get"
|
|
21
29
|
HTTP_POST = "post"
|
|
22
30
|
HTTP_PUT = "put"
|
|
@@ -62,60 +70,51 @@ class AsyncMSAL:
|
|
|
62
70
|
Use until such time as MSAL Python gets a true async version.
|
|
63
71
|
"""
|
|
64
72
|
|
|
65
|
-
|
|
66
|
-
_app: ConfidentialClientApplication = None
|
|
67
|
-
_clientsession: ClientSession = None # type: ignore
|
|
73
|
+
client_session: ClassVar[ClientSession | None] = None
|
|
68
74
|
|
|
69
75
|
def __init__(
|
|
70
76
|
self,
|
|
71
|
-
session: Session | dict[str,
|
|
72
|
-
|
|
77
|
+
session: Session | dict[str, Any],
|
|
78
|
+
save_callback: Callable[[Session | dict[str, Any]], None] | None = None,
|
|
73
79
|
):
|
|
74
80
|
"""Init the class.
|
|
75
81
|
|
|
76
|
-
**
|
|
82
|
+
**save_callback** will be called if the token cache changes. Optional.
|
|
77
83
|
Not required when the session parameter is an aiohttp_session.Session.
|
|
78
84
|
"""
|
|
79
85
|
self.session = session
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
if not isinstance(session, (Session, dict)):
|
|
86
|
+
self.save_callback = save_callback
|
|
87
|
+
if not isinstance(session, Session | dict):
|
|
83
88
|
raise ValueError(f"session or dict-like object required {session}")
|
|
84
89
|
|
|
85
|
-
@
|
|
90
|
+
@cached_property
|
|
86
91
|
def token_cache(self) -> SerializableTokenCache:
|
|
87
92
|
"""Get the token_cache."""
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
self._token_cache.deserialize(self.session[TOKEN_CACHE])
|
|
93
|
-
|
|
94
|
-
return self._token_cache
|
|
93
|
+
res = SerializableTokenCache()
|
|
94
|
+
if self.session and self.session.get(TOKEN_CACHE):
|
|
95
|
+
res.deserialize(self.session[TOKEN_CACHE])
|
|
96
|
+
return res
|
|
95
97
|
|
|
96
|
-
@
|
|
98
|
+
@cached_property
|
|
97
99
|
def app(self) -> ConfidentialClientApplication:
|
|
98
100
|
"""Create the application using the cache.
|
|
99
101
|
|
|
100
102
|
Based on: https://github.com/Azure-Samples/ms-identity-python-webapp/blob/master/app.py#L76
|
|
101
103
|
"""
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
token_cache=token_cache,
|
|
110
|
-
)
|
|
111
|
-
return self._app
|
|
104
|
+
return ConfidentialClientApplication(
|
|
105
|
+
client_id=ENV.SP_APP_ID,
|
|
106
|
+
client_credential=ENV.SP_APP_PW,
|
|
107
|
+
authority=ENV.SP_AUTHORITY, # common/oauth2/v2.0/token'
|
|
108
|
+
validate_authority=False,
|
|
109
|
+
token_cache=self.token_cache,
|
|
110
|
+
)
|
|
112
111
|
|
|
113
|
-
def
|
|
112
|
+
def save_token_cache(self) -> None:
|
|
114
113
|
"""Save the token cache if it changed."""
|
|
115
114
|
if self.token_cache.has_state_changed:
|
|
116
115
|
self.session[TOKEN_CACHE] = self.token_cache.serialize()
|
|
117
|
-
if
|
|
118
|
-
self.
|
|
116
|
+
if self.save_callback:
|
|
117
|
+
self.save_callback(self.session)
|
|
119
118
|
|
|
120
119
|
def build_auth_code_flow(
|
|
121
120
|
self,
|
|
@@ -125,8 +124,8 @@ class AsyncMSAL:
|
|
|
125
124
|
**kwargs: Any,
|
|
126
125
|
) -> str:
|
|
127
126
|
"""First step - Start the flow."""
|
|
128
|
-
self.session[TOKEN_CACHE] = None
|
|
129
|
-
self.session[USER_EMAIL] = None
|
|
127
|
+
self.session[TOKEN_CACHE] = None
|
|
128
|
+
self.session[USER_EMAIL] = None
|
|
130
129
|
self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
|
|
131
130
|
scopes or DEFAULT_SCOPES,
|
|
132
131
|
redirect_uri=redirect_uri,
|
|
@@ -149,10 +148,9 @@ class AsyncMSAL:
|
|
|
149
148
|
raise web.HTTPBadRequest(text=str(result["error"]))
|
|
150
149
|
if "id_token_claims" not in result:
|
|
151
150
|
raise web.HTTPBadRequest(text=f"Expected id_token_claims in {result}")
|
|
152
|
-
self.
|
|
153
|
-
|
|
154
|
-
"preferred_username"
|
|
155
|
-
)
|
|
151
|
+
self.save_token_cache()
|
|
152
|
+
if tok := result.get("id_token_claims"):
|
|
153
|
+
self.session[USER_EMAIL] = tok.get("preferred_username")
|
|
156
154
|
|
|
157
155
|
async def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
|
|
158
156
|
"""Second step - Acquire token, async version."""
|
|
@@ -167,7 +165,7 @@ class AsyncMSAL:
|
|
|
167
165
|
result = self.app.acquire_token_silent(
|
|
168
166
|
scopes=scopes or DEFAULT_SCOPES, account=accounts[0]
|
|
169
167
|
)
|
|
170
|
-
self.
|
|
168
|
+
self.save_token_cache()
|
|
171
169
|
return result
|
|
172
170
|
return None
|
|
173
171
|
|
|
@@ -175,7 +173,9 @@ class AsyncMSAL:
|
|
|
175
173
|
"""Acquire a token based on username."""
|
|
176
174
|
return await asyncio.get_event_loop().run_in_executor(None, self.get_token)
|
|
177
175
|
|
|
178
|
-
async def request(
|
|
176
|
+
async def request(
|
|
177
|
+
self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
|
|
178
|
+
) -> ClientResponse:
|
|
179
179
|
"""Make a request to url using an oauth session.
|
|
180
180
|
|
|
181
181
|
:param str url: url to send request to
|
|
@@ -184,16 +184,14 @@ class AsyncMSAL:
|
|
|
184
184
|
:return: Response of the request
|
|
185
185
|
:rtype: aiohttp.Response
|
|
186
186
|
"""
|
|
187
|
-
if not self._clientsession:
|
|
188
|
-
AsyncMSAL._clientsession = ClientSession(trust_env=True)
|
|
189
|
-
|
|
190
187
|
token = await self.async_get_token()
|
|
191
188
|
if token is None:
|
|
192
189
|
raise web.HTTPClientError(text="No login token available.")
|
|
193
190
|
|
|
194
191
|
kwargs = kwargs.copy()
|
|
195
192
|
# Ensure headers exist & make a copy
|
|
196
|
-
|
|
193
|
+
headers = dict[str, str](kwargs.get("headers") or {}) # type:ignore[arg-type]
|
|
194
|
+
kwargs["headers"] = headers
|
|
197
195
|
|
|
198
196
|
headers["Authorization"] = "Bearer " + token["access_token"]
|
|
199
197
|
|
|
@@ -207,22 +205,24 @@ class AsyncMSAL:
|
|
|
207
205
|
if "data" in kwargs:
|
|
208
206
|
kwargs["data"] = json.dumps(kwargs["data"]) # auto convert to json
|
|
209
207
|
|
|
210
|
-
|
|
208
|
+
if not AsyncMSAL.client_session:
|
|
209
|
+
AsyncMSAL.client_session = ClientSession(trust_env=True)
|
|
211
210
|
|
|
212
|
-
return
|
|
211
|
+
return await AsyncMSAL.client_session.request(method, url, **kwargs)
|
|
213
212
|
|
|
214
|
-
def
|
|
215
|
-
|
|
216
|
-
|
|
213
|
+
def request_ctx(
|
|
214
|
+
self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
|
|
215
|
+
) -> _RequestContextManager:
|
|
216
|
+
"""Request context manager."""
|
|
217
|
+
return _RequestContextManager(self.request(method, url, **kwargs))
|
|
217
218
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
|
|
219
|
+
get = partialmethod(request_ctx, HTTP_GET)
|
|
220
|
+
post = partialmethod(request_ctx, HTTP_POST)
|
|
221
221
|
|
|
222
222
|
@property
|
|
223
223
|
def mail(self) -> str:
|
|
224
224
|
"""User email."""
|
|
225
|
-
return self.session.get(
|
|
225
|
+
return self.session.get(USER_EMAIL, "")
|
|
226
226
|
|
|
227
227
|
@property
|
|
228
228
|
def manager_mail(self) -> str:
|
|
@@ -4,8 +4,9 @@ import asyncio
|
|
|
4
4
|
import json
|
|
5
5
|
import logging
|
|
6
6
|
import time
|
|
7
|
+
from collections.abc import AsyncGenerator
|
|
7
8
|
from contextlib import AsyncExitStack, asynccontextmanager
|
|
8
|
-
from typing import Any
|
|
9
|
+
from typing import Any
|
|
9
10
|
|
|
10
11
|
from redis.asyncio import Redis, from_url
|
|
11
12
|
|
|
@@ -30,15 +31,15 @@ async def get_redis() -> AsyncGenerator[Redis, None]:
|
|
|
30
31
|
try:
|
|
31
32
|
yield redis
|
|
32
33
|
finally:
|
|
33
|
-
MENV.database = None # type:ignore
|
|
34
|
+
MENV.database = None # type:ignore[assignment]
|
|
34
35
|
await redis.close()
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
async def session_iter(
|
|
38
39
|
redis: Redis,
|
|
39
40
|
*,
|
|
40
|
-
match:
|
|
41
|
-
key_match:
|
|
41
|
+
match: dict[str, str] | None = None,
|
|
42
|
+
key_match: str | None = None,
|
|
42
43
|
) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
|
|
43
44
|
"""Iterate over the Redis keys to find a specific session.
|
|
44
45
|
|
|
@@ -55,10 +56,10 @@ async def session_iter(
|
|
|
55
56
|
sval = await redis.get(key)
|
|
56
57
|
created, ses = 0, {}
|
|
57
58
|
try:
|
|
58
|
-
val = json.loads(sval) # type: ignore
|
|
59
|
+
val = json.loads(sval) # type: ignore[arg-type]
|
|
59
60
|
created = int(val["created"])
|
|
60
61
|
ses = val["session"]
|
|
61
|
-
except Exception:
|
|
62
|
+
except Exception:
|
|
62
63
|
pass
|
|
63
64
|
if match:
|
|
64
65
|
# Ensure we match all the supplied terms
|
|
@@ -73,7 +74,7 @@ async def session_iter(
|
|
|
73
74
|
|
|
74
75
|
|
|
75
76
|
async def session_clean(
|
|
76
|
-
redis: Redis, *, max_age: int = 90, expected_keys:
|
|
77
|
+
redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
|
|
77
78
|
) -> None:
|
|
78
79
|
"""Clear session entries older than max_age days."""
|
|
79
80
|
rem, keep = 0, 0
|
|
@@ -93,11 +94,29 @@ async def session_clean(
|
|
|
93
94
|
_LOGGER.debug("No sessions removed (%s total)", keep)
|
|
94
95
|
|
|
95
96
|
|
|
96
|
-
def
|
|
97
|
+
async def invalid_sessions(redis: Redis) -> None:
|
|
98
|
+
"""Find & clean invalid sessions."""
|
|
99
|
+
async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
|
|
100
|
+
if not isinstance(key, str):
|
|
101
|
+
key = key.decode()
|
|
102
|
+
sval = await redis.get(key)
|
|
103
|
+
if sval is None:
|
|
104
|
+
continue
|
|
105
|
+
try:
|
|
106
|
+
val: dict = json.loads(sval)
|
|
107
|
+
assert isinstance(val["created"], int)
|
|
108
|
+
assert isinstance(val["session"], dict)
|
|
109
|
+
except Exception as err:
|
|
110
|
+
_LOGGER.warning("Removing session %s: %s", key, err)
|
|
111
|
+
await redis.delete(key)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
|
|
97
115
|
"""Create a AsyncMSAL session.
|
|
98
116
|
|
|
99
117
|
When get_token refreshes the token retrieved from Redis, the save_cache callback
|
|
100
|
-
will be responsible to update the cache in Redis.
|
|
118
|
+
will be responsible to update the cache in Redis.
|
|
119
|
+
"""
|
|
101
120
|
|
|
102
121
|
async def async_save_cache(_: dict) -> None:
|
|
103
122
|
"""Save the token cache to Redis."""
|
|
@@ -111,11 +130,11 @@ def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
|
|
|
111
130
|
except RuntimeError:
|
|
112
131
|
asyncio.run(async_save_cache(*args))
|
|
113
132
|
|
|
114
|
-
return AsyncMSAL(session,
|
|
133
|
+
return AsyncMSAL(session, save_callback=save_cache)
|
|
115
134
|
|
|
116
135
|
|
|
117
136
|
async def get_session(
|
|
118
|
-
email: str, *, redis:
|
|
137
|
+
email: str, *, redis: Redis | None = None, scope: str = ""
|
|
119
138
|
) -> AsyncMSAL:
|
|
120
139
|
"""Get a session from Redis."""
|
|
121
140
|
cnt = 0
|
|
@@ -126,7 +145,7 @@ async def get_session(
|
|
|
126
145
|
cnt += 1
|
|
127
146
|
if scope and scope not in str(session.get("token_cache")).lower():
|
|
128
147
|
continue
|
|
129
|
-
return _session_factory(key,
|
|
148
|
+
return _session_factory(key, created, session)
|
|
130
149
|
msg = f"Session for {email}"
|
|
131
150
|
if not scope:
|
|
132
151
|
raise ValueError(f"{msg} not found")
|
|
@@ -136,7 +155,7 @@ async def get_session(
|
|
|
136
155
|
async def redis_get_json(key: str) -> list | dict | None:
|
|
137
156
|
"""Get a key from redis."""
|
|
138
157
|
res = await MENV.database.get(key)
|
|
139
|
-
if isinstance(res,
|
|
158
|
+
if isinstance(res, str | bytes | bytearray):
|
|
140
159
|
return json.loads(res)
|
|
141
160
|
if res is not None:
|
|
142
161
|
_LOGGER.warning("Unexpected type for %s: %s", key, type(res))
|
|
@@ -148,7 +167,7 @@ async def redis_get(key: str) -> str | None:
|
|
|
148
167
|
res = await MENV.database.get(key)
|
|
149
168
|
if isinstance(res, str):
|
|
150
169
|
return res
|
|
151
|
-
if isinstance(res,
|
|
170
|
+
if isinstance(res, bytes | bytearray):
|
|
152
171
|
return res.decode()
|
|
153
172
|
if res is not None:
|
|
154
173
|
_LOGGER.warning("Unexpected type for %s: %s", key, type(res))
|