aiohttp-msal 0.7.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aiohttp_msal/__init__.py CHANGED
@@ -1,9 +1,10 @@
1
1
  """aiohttp_msal."""
2
2
 
3
3
  import logging
4
- import typing
4
+ from collections.abc import Awaitable, Callable
5
5
  from functools import wraps
6
6
  from inspect import getfullargspec, iscoroutinefunction
7
+ from typing import TypeVar, TypeVarTuple, cast
7
8
 
8
9
  from aiohttp import ClientSession, web
9
10
  from aiohttp_session import get_session
@@ -14,43 +15,49 @@ from aiohttp_msal.settings import ENV
14
15
 
15
16
  _LOGGER = logging.getLogger(__name__)
16
17
 
17
- VERSION = "0.7.1"
18
+ _T = TypeVar("_T")
19
+ Ts = TypeVarTuple("Ts")
18
20
 
19
21
 
20
22
  def msal_session(
21
- *callbacks: typing.Callable[[AsyncMSAL], bool | typing.Awaitable[bool]],
23
+ *callbacks: Callable[[AsyncMSAL], bool | Awaitable[bool]],
22
24
  at_least_one: bool | None = False,
23
- ) -> typing.Callable:
25
+ ) -> Callable[
26
+ [Callable[[*Ts, AsyncMSAL], Awaitable[_T]]], Callable[[*Ts], Awaitable[_T]]
27
+ ]:
24
28
  """Session decorator.
25
29
 
26
30
  Arguments can include a list of function to perform login tests etc.
27
31
  """
28
32
 
29
- def _session(func: typing.Callable) -> typing.Callable:
33
+ def check_session(
34
+ func: Callable[[*Ts, AsyncMSAL], Awaitable[_T]],
35
+ ) -> Callable[[*Ts], Awaitable[_T]]:
30
36
  @wraps(func)
31
- async def __session(request: web.Request) -> typing.Callable:
37
+ async def wrapper(*args: *Ts) -> _T:
38
+ if len(args) < 1:
39
+ raise AssertionError("Requires a Request as the first parameter")
40
+ request = cast(web.Request, args[0])
32
41
  ses = AsyncMSAL(session=await get_session(request))
33
42
  for c_b in callbacks:
34
43
  _ok = await c_b(ses) if iscoroutinefunction(c_b) else c_b(ses)
35
44
 
36
45
  if at_least_one:
37
46
  if _ok:
38
- return await func(request=request, ses=ses)
39
- continue
40
-
41
- if not _ok:
47
+ return await func(*args, ses)
48
+ elif not _ok:
42
49
  raise web.HTTPForbidden
43
50
 
44
51
  if at_least_one:
45
52
  raise web.HTTPForbidden
46
- return await func(request=request, ses=ses)
53
+ return await func(*args, ses)
47
54
 
48
55
  assert iscoroutinefunction(func), f"Function needs to be a coroutine: {func}"
49
56
  spec = getfullargspec(func)
50
57
  assert "ses" in spec.args, f"Function needs to accept a session 'ses': {func}"
51
- return __session
58
+ return wrapper
52
59
 
53
- return _session
60
+ return check_session
54
61
 
55
62
 
56
63
  def auth_ok(ses: AsyncMSAL) -> bool:
@@ -59,11 +66,12 @@ def auth_ok(ses: AsyncMSAL) -> bool:
59
66
 
60
67
 
61
68
  def auth_or(
62
- *args: typing.Callable[[AsyncMSAL], bool | typing.Awaitable[bool]]
63
- ) -> typing.Callable[[AsyncMSAL], typing.Awaitable[bool]]:
69
+ *args: Callable[[AsyncMSAL], bool | Awaitable[bool]],
70
+ ) -> Callable[[AsyncMSAL], Awaitable[bool]]:
64
71
  """Ensure either of the methods is valid. An alternative to at_least_one=True.
65
72
 
66
- Arguments can include a list of function to perform login tests etc."""
73
+ Arguments can include a list of function to perform login tests etc.
74
+ """
67
75
 
68
76
  async def or_auth(ses: AsyncMSAL) -> bool:
69
77
  """Or."""
@@ -81,11 +89,10 @@ def auth_or(
81
89
  async def app_init_redis_session(
82
90
  app: web.Application, max_age: int = 3600 * 24 * 90
83
91
  ) -> None:
84
- """OPTIONAL: Initialize aiohttp_session with Redis storage.
92
+ """Init an aiohttp_session with Redis storage helper.
85
93
 
86
94
  You can initialize your own aiohttp_session & storage provider.
87
95
  """
88
- # pylint: disable=import-outside-toplevel
89
96
  from aiohttp_session import redis_storage
90
97
  from redis.asyncio import from_url
91
98
 
@@ -93,7 +100,7 @@ async def app_init_redis_session(
93
100
 
94
101
  _LOGGER.info("Connect to Redis %s", ENV.REDIS)
95
102
  try:
96
- ENV.database = from_url(ENV.REDIS) # pylint: disable=no-member
103
+ ENV.database = from_url(ENV.REDIS)
97
104
  # , encoding="utf-8", decode_responses=True
98
105
  except ConnectionRefusedError as err:
99
106
  raise ConnectionError("Could not connect to REDIS server") from err
@@ -119,7 +126,7 @@ async def check_proxy() -> None:
119
126
  if resp.ok:
120
127
  return
121
128
  raise ConnectionError(await resp.text())
122
- except Exception as err: # pylint: disable=broad-except
129
+ except Exception as err:
123
130
  raise ConnectionError(
124
131
  "No connection to the Internet. Required for OAuth. Check your Proxy?"
125
132
  ) from err
@@ -7,16 +7,24 @@ Once you have the OAuth tokens store in the session, you are free to make reques
7
7
 
8
8
  import asyncio
9
9
  import json
10
- from functools import partial, wraps
11
- from typing import Any, Callable, Literal
10
+ from collections.abc import Callable
11
+ from functools import partial, partialmethod, wraps
12
+ from typing import Any, ClassVar, Literal, Unpack
12
13
 
13
14
  from aiohttp import web
14
- from aiohttp.client import ClientResponse, ClientSession, _RequestContextManager
15
+ from aiohttp.client import (
16
+ ClientResponse,
17
+ ClientSession,
18
+ _RequestContextManager,
19
+ _RequestOptions,
20
+ )
21
+ from aiohttp.typedefs import StrOrURL
15
22
  from aiohttp_session import Session
16
23
  from msal import ConfidentialClientApplication, SerializableTokenCache
17
24
 
18
25
  from aiohttp_msal.settings import ENV
19
26
 
27
+ HttpMethods = Literal["get", "post", "put", "patch", "delete"]
20
28
  HTTP_GET = "get"
21
29
  HTTP_POST = "post"
22
30
  HTTP_PUT = "put"
@@ -62,24 +70,23 @@ class AsyncMSAL:
62
70
  Use until such time as MSAL Python gets a true async version.
63
71
  """
64
72
 
65
- _token_cache: SerializableTokenCache = None
66
- _app: ConfidentialClientApplication = None
67
- _clientsession: ClientSession = None # type: ignore
73
+ _token_cache: SerializableTokenCache
74
+ _app: ConfidentialClientApplication
75
+ client_session: ClassVar[ClientSession | None] = None
68
76
 
69
77
  def __init__(
70
78
  self,
71
- session: Session | dict[str, str],
72
- save_cache: Callable[[Session | dict[str, str]], None] | None = None,
79
+ session: Session | dict[str, Any],
80
+ save_callback: Callable[[Session | dict[str, Any]], None] | None = None,
73
81
  ):
74
82
  """Init the class.
75
83
 
76
- **save_token_cache** will be called if the token cache changes. Optional.
84
+ **save_callback** will be called if the token cache changes. Optional.
77
85
  Not required when the session parameter is an aiohttp_session.Session.
78
86
  """
79
87
  self.session = session
80
- if save_cache:
81
- self.save_token_cache = save_cache
82
- if not isinstance(session, (Session, dict)):
88
+ self.save_callback = save_callback
89
+ if not isinstance(session, Session | dict):
83
90
  raise ValueError(f"session or dict-like object required {session}")
84
91
 
85
92
  @property
@@ -110,12 +117,12 @@ class AsyncMSAL:
110
117
  )
111
118
  return self._app
112
119
 
113
- def _save_token_cache(self) -> None:
120
+ def save_token_cache(self) -> None:
114
121
  """Save the token cache if it changed."""
115
122
  if self.token_cache.has_state_changed:
116
123
  self.session[TOKEN_CACHE] = self.token_cache.serialize()
117
- if hasattr(self, "save_token_cache"):
118
- self.save_token_cache(self.token_cache)
124
+ if self.save_callback:
125
+ self.save_callback(self.session)
119
126
 
120
127
  def build_auth_code_flow(
121
128
  self,
@@ -125,8 +132,8 @@ class AsyncMSAL:
125
132
  **kwargs: Any,
126
133
  ) -> str:
127
134
  """First step - Start the flow."""
128
- self.session[TOKEN_CACHE] = None # type: ignore
129
- self.session[USER_EMAIL] = None # type: ignore
135
+ self.session[TOKEN_CACHE] = None
136
+ self.session[USER_EMAIL] = None
130
137
  self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
131
138
  scopes or DEFAULT_SCOPES,
132
139
  redirect_uri=redirect_uri,
@@ -149,10 +156,9 @@ class AsyncMSAL:
149
156
  raise web.HTTPBadRequest(text=str(result["error"]))
150
157
  if "id_token_claims" not in result:
151
158
  raise web.HTTPBadRequest(text=f"Expected id_token_claims in {result}")
152
- self._save_token_cache()
153
- self.session[USER_EMAIL] = result.get("id_token_claims").get(
154
- "preferred_username"
155
- )
159
+ self.save_token_cache()
160
+ if tok := result.get("id_token_claims"):
161
+ self.session[USER_EMAIL] = tok.get("preferred_username")
156
162
 
157
163
  async def async_acquire_token_by_auth_code_flow(self, auth_response: Any) -> None:
158
164
  """Second step - Acquire token, async version."""
@@ -167,7 +173,7 @@ class AsyncMSAL:
167
173
  result = self.app.acquire_token_silent(
168
174
  scopes=scopes or DEFAULT_SCOPES, account=accounts[0]
169
175
  )
170
- self._save_token_cache()
176
+ self.save_token_cache()
171
177
  return result
172
178
  return None
173
179
 
@@ -175,7 +181,9 @@ class AsyncMSAL:
175
181
  """Acquire a token based on username."""
176
182
  return await asyncio.get_event_loop().run_in_executor(None, self.get_token)
177
183
 
178
- async def request(self, method: str, url: str, **kwargs: Any) -> ClientResponse:
184
+ async def request(
185
+ self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
186
+ ) -> ClientResponse:
179
187
  """Make a request to url using an oauth session.
180
188
 
181
189
  :param str url: url to send request to
@@ -184,16 +192,14 @@ class AsyncMSAL:
184
192
  :return: Response of the request
185
193
  :rtype: aiohttp.Response
186
194
  """
187
- if not self._clientsession:
188
- AsyncMSAL._clientsession = ClientSession(trust_env=True)
189
-
190
195
  token = await self.async_get_token()
191
196
  if token is None:
192
197
  raise web.HTTPClientError(text="No login token available.")
193
198
 
194
199
  kwargs = kwargs.copy()
195
200
  # Ensure headers exist & make a copy
196
- kwargs["headers"] = headers = dict(kwargs.get("headers", {}))
201
+ headers = dict[str, str](kwargs.get("headers") or {}) # type:ignore[arg-type]
202
+ kwargs["headers"] = headers
197
203
 
198
204
  headers["Authorization"] = "Bearer " + token["access_token"]
199
205
 
@@ -207,17 +213,27 @@ class AsyncMSAL:
207
213
  if "data" in kwargs:
208
214
  kwargs["data"] = json.dumps(kwargs["data"]) # auto convert to json
209
215
 
210
- response = await self._clientsession.request(method, url, **kwargs)
216
+ if not AsyncMSAL.client_session:
217
+ AsyncMSAL.client_session = ClientSession(trust_env=True)
218
+
219
+ return await AsyncMSAL.client_session.request(method, url, **kwargs)
220
+
221
+ def request_ctx(
222
+ self, method: HttpMethods, url: StrOrURL, **kwargs: Unpack[_RequestOptions]
223
+ ) -> _RequestContextManager:
224
+ """Request context manager."""
225
+ return _RequestContextManager(self.request(method, url, **kwargs))
211
226
 
212
- return response
227
+ get = partialmethod(request_ctx, HTTP_GET)
228
+ post = partialmethod(request_ctx, HTTP_POST)
213
229
 
214
- def get(self, url: str, **kwargs: Any): # type:ignore
215
- """GET Request."""
216
- return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))
230
+ # def get(self, url: str, **kwargs: Any) -> _RequestContextManager:
231
+ # """GET Request."""
232
+ # return _RequestContextManager(self.request(HTTP_GET, url, **kwargs))
217
233
 
218
- def post(self, url: str, **kwargs: Any): # type:ignore
219
- """POST request."""
220
- return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
234
+ # def post(self, url: str, **kwargs: Any) -> _RequestContextManager:
235
+ # """POST request."""
236
+ # return _RequestContextManager(self.request(HTTP_POST, url, **kwargs))
221
237
 
222
238
  @property
223
239
  def mail(self) -> str:
@@ -4,8 +4,9 @@ import asyncio
4
4
  import json
5
5
  import logging
6
6
  import time
7
+ from collections.abc import AsyncGenerator
7
8
  from contextlib import AsyncExitStack, asynccontextmanager
8
- from typing import Any, AsyncGenerator, Optional
9
+ from typing import Any
9
10
 
10
11
  from redis.asyncio import Redis, from_url
11
12
 
@@ -30,15 +31,15 @@ async def get_redis() -> AsyncGenerator[Redis, None]:
30
31
  try:
31
32
  yield redis
32
33
  finally:
33
- MENV.database = None # type:ignore
34
+ MENV.database = None # type:ignore[assignment]
34
35
  await redis.close()
35
36
 
36
37
 
37
38
  async def session_iter(
38
39
  redis: Redis,
39
40
  *,
40
- match: Optional[dict[str, str]] = None,
41
- key_match: Optional[str] = None,
41
+ match: dict[str, str] | None = None,
42
+ key_match: str | None = None,
42
43
  ) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
43
44
  """Iterate over the Redis keys to find a specific session.
44
45
 
@@ -55,10 +56,10 @@ async def session_iter(
55
56
  sval = await redis.get(key)
56
57
  created, ses = 0, {}
57
58
  try:
58
- val = json.loads(sval) # type: ignore
59
+ val = json.loads(sval) # type: ignore[arg-type]
59
60
  created = int(val["created"])
60
61
  ses = val["session"]
61
- except Exception: # pylint: disable=broad-except
62
+ except Exception:
62
63
  pass
63
64
  if match:
64
65
  # Ensure we match all the supplied terms
@@ -73,7 +74,7 @@ async def session_iter(
73
74
 
74
75
 
75
76
  async def session_clean(
76
- redis: Redis, *, max_age: int = 90, expected_keys: Optional[dict] = None
77
+ redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
77
78
  ) -> None:
78
79
  """Clear session entries older than max_age days."""
79
80
  rem, keep = 0, 0
@@ -93,11 +94,29 @@ async def session_clean(
93
94
  _LOGGER.debug("No sessions removed (%s total)", keep)
94
95
 
95
96
 
96
- def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
97
+ async def invalid_sessions(redis: Redis) -> None:
98
+ """Find & clean invalid sessions."""
99
+ async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
100
+ if not isinstance(key, str):
101
+ key = key.decode()
102
+ sval = await redis.get(key)
103
+ if sval is None:
104
+ continue
105
+ try:
106
+ val: dict = json.loads(sval)
107
+ assert isinstance(val["created"], int)
108
+ assert isinstance(val["session"], dict)
109
+ except Exception as err:
110
+ _LOGGER.warning("Removing session %s: %s", key, err)
111
+ await redis.delete(key)
112
+
113
+
114
+ def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
97
115
  """Create a AsyncMSAL session.
98
116
 
99
117
  When get_token refreshes the token retrieved from Redis, the save_cache callback
100
- will be responsible to update the cache in Redis."""
118
+ will be responsible to update the cache in Redis.
119
+ """
101
120
 
102
121
  async def async_save_cache(_: dict) -> None:
103
122
  """Save the token cache to Redis."""
@@ -111,11 +130,11 @@ def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
111
130
  except RuntimeError:
112
131
  asyncio.run(async_save_cache(*args))
113
132
 
114
- return AsyncMSAL(session, save_cache=save_cache)
133
+ return AsyncMSAL(session, save_callback=save_cache)
115
134
 
116
135
 
117
136
  async def get_session(
118
- email: str, *, redis: Optional[Redis] = None, scope: str = ""
137
+ email: str, *, redis: Redis | None = None, scope: str = ""
119
138
  ) -> AsyncMSAL:
120
139
  """Get a session from Redis."""
121
140
  cnt = 0
@@ -126,7 +145,7 @@ async def get_session(
126
145
  cnt += 1
127
146
  if scope and scope not in str(session.get("token_cache")).lower():
128
147
  continue
129
- return _session_factory(key, str(created), session)
148
+ return _session_factory(key, created, session)
130
149
  msg = f"Session for {email}"
131
150
  if not scope:
132
151
  raise ValueError(f"{msg} not found")
@@ -136,7 +155,7 @@ async def get_session(
136
155
  async def redis_get_json(key: str) -> list | dict | None:
137
156
  """Get a key from redis."""
138
157
  res = await MENV.database.get(key)
139
- if isinstance(res, (str, bytes, bytearray)):
158
+ if isinstance(res, str | bytes | bytearray):
140
159
  return json.loads(res)
141
160
  if res is not None:
142
161
  _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
@@ -148,7 +167,7 @@ async def redis_get(key: str) -> str | None:
148
167
  res = await MENV.database.get(key)
149
168
  if isinstance(res, str):
150
169
  return res
151
- if isinstance(res, (bytes, bytearray)):
170
+ if isinstance(res, bytes | bytearray):
152
171
  return res.decode()
153
172
  if res is not None:
154
173
  _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
aiohttp_msal/routes.py CHANGED
@@ -1,8 +1,9 @@
1
1
  """The user blueprint."""
2
2
 
3
3
  import time
4
+ from collections.abc import Mapping, Sequence
4
5
  from inspect import iscoroutinefunction
5
- from typing import Any, Mapping, Sequence
6
+ from typing import Any
6
7
  from urllib.parse import urljoin
7
8
 
8
9
  from aiohttp import web
@@ -63,8 +64,7 @@ async def user_authorized(request: web.Request) -> web.Response:
63
64
  # Ensure all expected variables were returned...
64
65
  if not all(auth_response.get(k) for k in ["code", "session_state", "state"]):
65
66
  msg.append(
66
- "<b>Expecting code,state,session_state in post body.</b>"
67
- f"auth_response: {auth_response}"
67
+ f"<b>Expecting code,state,session_state in post body.</b>auth_response: {auth_response}"
68
68
  )
69
69
 
70
70
  if not request.cookies.get(ENV.COOKIE_NAME):
@@ -83,7 +83,7 @@ async def user_authorized(request: web.Request) -> web.Response:
83
83
  if not msg:
84
84
  try:
85
85
  await aiomsal.async_acquire_token_by_auth_code_flow(auth_response)
86
- except Exception as err: # pylint: disable=broad-except
86
+ except Exception as err:
87
87
  msg.append(
88
88
  "<b>Could not get token</b> - async_acquire_token_by_auth_code_flow"
89
89
  )
@@ -95,7 +95,7 @@ async def user_authorized(request: web.Request) -> web.Response:
95
95
  try:
96
96
  await get_user_info(aiomsal)
97
97
  await get_manager_info(aiomsal)
98
- except Exception as err: # pylint: disable=broad-except
98
+ except Exception as err:
99
99
  msg.append("Could not get org info from MS graph")
100
100
  msg.append(str(err))
101
101
  if session.get("mail"):
@@ -196,8 +196,7 @@ async def user_logout(request: web.Request, ses: AsyncMSAL) -> web.Response:
196
196
  _to = get_route(request, _to)
197
197
 
198
198
  return web.HTTPFound(
199
- "https://login.microsoftonline.com/common/oauth2/logout?"
200
- f"post_logout_redirect_uri={_to}"
199
+ f"https://login.microsoftonline.com/common/oauth2/logout?post_logout_redirect_uri={_to}"
201
200
  ) # redirect
202
201
 
203
202
 
aiohttp_msal/settings.py CHANGED
@@ -1,43 +1,47 @@
1
1
  """Settings."""
2
2
 
3
- from typing import TYPE_CHECKING, Any, Awaitable, Callable
3
+ from collections.abc import Awaitable, Callable
4
+ from typing import TYPE_CHECKING, Any
4
5
 
5
- from aiohttp_msal.settings_base import SettingsBase, Var
6
+ import attrs
7
+
8
+ from aiohttp_msal.settings_base import VAR_REQ, VAR_REQ_HIDE, SettingsBase
6
9
 
7
10
  if TYPE_CHECKING:
8
11
  from redis.asyncio import Redis
9
12
  else:
10
- Redis = Any
13
+ Redis = None
11
14
 
12
15
 
16
+ @attrs.define
13
17
  class MSALSettings(SettingsBase):
14
18
  """Settings."""
15
19
 
16
- SP_APP_ID = Var(str, required=True)
20
+ SP_APP_ID: str = attrs.field(metadata=VAR_REQ, default="")
17
21
  """SharePoint Application ID."""
18
- SP_APP_PW = Var(str, required=True)
22
+ SP_APP_PW: str = attrs.field(metadata=VAR_REQ_HIDE, default="")
19
23
  """SharePoint Application Secret."""
20
- SP_AUTHORITY = Var(str, required=True)
24
+ SP_AUTHORITY: str = attrs.field(metadata=VAR_REQ, default="")
21
25
  """SharePoint Authority URL.
22
26
 
23
27
  Examples:
24
28
  "https://login.microsoftonline.com/common" # For multi-tenant app
25
29
  "https://login.microsoftonline.com/Tenant_Name_or_UUID_Here"."""
26
30
 
27
- DOMAIN = "mydomain.com"
31
+ DOMAIN: str = "mydomain.com"
28
32
  """Your domain. Used by routes & Redis functions."""
29
33
 
30
- COOKIE_NAME = "AIOHTTP_SESSION"
34
+ COOKIE_NAME: str = "AIOHTTP_SESSION"
31
35
  """The name of the cookie with the session identifier."""
32
36
 
33
- login_callback: list[Callable[[Any], Awaitable[Any]]] = []
37
+ login_callback: list[Callable[[Any], Awaitable[Any]]] = attrs.field(factory=list)
34
38
  """A list of callbacks to execute on successful login."""
35
- info: dict[str, Callable[[Any], Any | Awaitable[Any]]] = {}
39
+ info: dict[str, Callable[[Any], Any | Awaitable[Any]]] = attrs.field(factory=dict)
36
40
  """List of attributes to return in /user/info."""
37
41
 
38
- REDIS = "redis://redis1:6379"
42
+ REDIS: str = "redis://redis1:6379"
39
43
  """OPTIONAL: Redis database connection used by app_init_redis_session()."""
40
- database: Redis = None # type: ignore
44
+ database: Redis = None # type: ignore[assignment]
41
45
  """Store the Redis connection when using app_init_redis_session()."""
42
46
 
43
47
 
@@ -1,29 +1,25 @@
1
1
  """Settings Base."""
2
2
 
3
- from __future__ import annotations
4
3
  import logging
5
4
  import os
6
5
  from pathlib import Path
7
- from typing import Any, Type
6
+ from typing import Any
8
7
 
8
+ import attrs
9
9
 
10
- class Var: # pylint: disable=too-few-public-methods
11
- """Variable settings."""
10
+ KEY_REQ = "required"
11
+ KEY_HIDE = "hide"
12
+ VAR_REQ_HIDE = {KEY_REQ: True, KEY_HIDE: True}
13
+ VAR_REQ = {KEY_REQ: True}
14
+ VAR_HIDE = {KEY_HIDE: True}
12
15
 
13
- @staticmethod
14
- def from_value(val: Any) -> Var:
15
- """Ensure the return is an instance of Var."""
16
- return val if isinstance(val, Var) else Var(type(val))
17
16
 
18
- def __init__(
19
- self, var_type: Type, hidden: bool = False, required: bool = False
20
- ) -> None:
21
- """Init class."""
22
- self.v_type = var_type
23
- self.hide = hidden
24
- self.required = required
17
+ def _is_hidden(atr: attrs.Attribute) -> bool:
18
+ """Is this field hidden."""
19
+ return bool(atr.metadata.get(KEY_HIDE))
25
20
 
26
21
 
22
+ @attrs.define
27
23
  class SettingsBase:
28
24
  """Retrieve Settings from environment variables.
29
25
 
@@ -34,50 +30,64 @@ class SettingsBase:
34
30
  convert environment variables to match the type of the value here.
35
31
  """
36
32
 
37
- _vars: dict[str, Var] = {}
38
- _env_prefix = ""
33
+ _env_prefix: str = attrs.field(init=False, default="")
34
+
35
+ def _get_fields(self) -> dict[str, attrs.Attribute]:
36
+ """Get env."""
37
+ res: list[attrs.Attribute] = [
38
+ a for a in attrs.fields(self.__class__) if a.name.isupper()
39
+ ]
40
+
41
+ dirs = [f for f in dir(self) if f.isupper()]
42
+ if len(dirs) != len(res):
43
+ for atr in res:
44
+ dirs.remove(atr.name)
45
+ raise AssertionError(f"There are UPPERCASE fields without a type!: {dirs}")
46
+
47
+ return {f"{self._env_prefix}{a.name}": a for a in res}
39
48
 
40
49
  def load(self, environment_prefix: str = "") -> None:
41
50
  """Initialize."""
42
- self._env_prefix = environment_prefix
43
51
  logger = logging.getLogger(__name__)
44
- attrs = [a for a in dir(self) if not a.startswith("_") and a.upper() == a]
45
- for name in attrs:
46
- curv = getattr(self, name)
47
- newv: Any = os.getenv(environment_prefix + name.upper())
48
- if isinstance(curv, Var):
49
- self._vars[name] = curv
50
- info = self._vars.get(name) or Var(type(curv))
51
- if not newv:
52
- if info.required:
53
- raise ValueError(f"Required value for {name} not provided")
52
+ self._env_prefix = environment_prefix.upper()
53
+ for ename, atr in self._get_fields().items():
54
+ newv = os.getenv(ename)
55
+ if newv is None:
56
+ if atr.metadata.get(KEY_REQ):
57
+ raise ValueError(f"Required value missing: {ename}")
54
58
  continue
55
59
  if newv.startswith('"') and newv.endswith('"'):
56
60
  newv = newv.strip('"')
57
- logger.debug("ENV %s = %s", name, "***" if info.hide else newv)
58
-
59
- if issubclass(info.v_type, bool):
60
- newv = newv.upper() in ("1", "TRUE")
61
- elif issubclass(info.v_type, int):
62
- newv = int(newv)
63
- elif issubclass(info.v_type, Path):
64
- newv = Path(newv)
65
- elif issubclass(info.v_type, bytes):
66
- newv = newv.encode()
67
-
68
- if name.endswith("_URI") and not newv.endswith("/"):
69
- newv += "/"
70
- setattr(self, name, newv)
71
-
72
- def to_dict(self, as_string: bool = False) -> dict[str, Any]:
61
+
62
+ curv = getattr(self, atr.name)
63
+ v_type = atr.type or type(curv)
64
+
65
+ if issubclass(v_type, bool):
66
+ setattr(self, atr.name, newv.upper() in ("1", "TRUE"))
67
+ elif issubclass(v_type, int):
68
+ setattr(self, atr.name, int(newv))
69
+ elif issubclass(v_type, Path):
70
+ setattr(self, atr.name, Path(newv))
71
+ elif issubclass(v_type, bytes):
72
+ setattr(self, atr.name, newv.encode())
73
+ else:
74
+ if atr.name.endswith("_URI") and not newv.endswith("/"):
75
+ newv += "/"
76
+ setattr(self, atr.name, newv)
77
+
78
+ logger.debug(
79
+ "ENV %s%s = %s",
80
+ self._env_prefix,
81
+ atr.name,
82
+ "***" if atr.metadata.get(KEY_HIDE) else getattr(self, atr.name),
83
+ )
84
+
85
+ def asdict(self, as_string: bool = False) -> dict[str, Any]:
73
86
  """Get all variables."""
74
87
  res = {}
75
- for name in vars(self):
76
- if name.startswith("_") or name.upper() != name:
77
- continue
78
- curv = getattr(self, name)
79
- info = self._vars.get(name) or Var(type(curv))
80
- if info.hide:
88
+ for ename, atr in self._get_fields().items():
89
+ curv = getattr(self, atr.name)
90
+ if atr.metadata.get(KEY_HIDE):
81
91
  continue
82
- res[self._env_prefix + name] = str(curv) if as_string else curv
92
+ res[ename] = str(curv) if as_string else curv
83
93
  return res
aiohttp_msal/user_info.py CHANGED
@@ -1,24 +1,28 @@
1
1
  """Graph User Info."""
2
2
 
3
3
  import asyncio
4
+ from collections.abc import Awaitable, Callable
4
5
  from functools import wraps
5
- from typing import Any, Callable
6
+ from typing import ParamSpec, TypeVar
6
7
 
7
8
  from aiohttp_msal.msal_async import AsyncMSAL
8
9
 
10
+ _T = TypeVar("_T")
11
+ _P = ParamSpec("_P")
9
12
 
10
- def retry(func: Callable) -> Callable:
13
+
14
+ def retry(func: Callable[_P, Awaitable[_T]]) -> Callable[_P, Awaitable[_T]]:
11
15
  """Retry if tenacity is installed."""
12
16
 
13
17
  @wraps(func)
14
- async def _retry(*args: Any, **kwargs: Any) -> Any:
18
+ async def _retry(*args: _P.args, **kwargs: _P.kwargs) -> _T:
15
19
  """Retry the request."""
16
20
  retries = [2, 4, 8]
17
21
  while True:
18
22
  try:
19
23
  res = await func(*args, **kwargs)
20
24
  return res
21
- except Exception as err: # pylint: disable=broad-except
25
+ except Exception as err:
22
26
  if retries:
23
27
  await asyncio.sleep(retries.pop())
24
28
  else:
@@ -1,37 +1,27 @@
1
- Metadata-Version: 2.1
2
- Name: aiohttp_msal
3
- Version: 0.7.1
1
+ Metadata-Version: 2.3
2
+ Name: aiohttp-msal
3
+ Version: 1.0.0
4
4
  Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
5
- Home-page: https://github.com/kellerza/aiohttp_msal
5
+ Keywords: aiohttp,asyncio,msal,oauth
6
6
  Author: Johann Kellerman
7
- Author-email: kellerza@gmail.com
7
+ Author-email: Johann Kellerman <kellerza@gmail.com>
8
8
  License: MIT
9
- Keywords: msal,oauth,aiohttp,asyncio
10
9
  Classifier: Development Status :: 4 - Beta
11
10
  Classifier: Intended Audience :: Developers
12
11
  Classifier: Natural Language :: English
13
- Classifier: Programming Language :: Python :: 3
14
12
  Classifier: Programming Language :: Python :: 3 :: Only
15
- Classifier: Programming Language :: Python :: 3.10
16
13
  Classifier: Programming Language :: Python :: 3.11
17
14
  Classifier: Programming Language :: Python :: 3.12
18
- Requires-Python: >=3.10
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Requires-Dist: aiohttp>=3.11.18,<3.13
17
+ Requires-Dist: aiohttp-session>=2.12.1,<3
18
+ Requires-Dist: attrs>=25.3,<26
19
+ Requires-Dist: msal>=1.32.3,<2
20
+ Requires-Dist: aiohttp-session[aioredis]>=2.12.1,<3 ; extra == 'aioredis'
21
+ Requires-Python: >=3.11
22
+ Project-URL: Homepage, https://github.com/kellerza/aiohttp_msal
23
+ Provides-Extra: aioredis
19
24
  Description-Content-Type: text/markdown
20
- License-File: LICENSE
21
- Requires-Dist: msal >=1.30.0
22
- Requires-Dist: aiohttp-session >=2.12
23
- Requires-Dist: aiohttp >=3.8
24
- Provides-Extra: redis
25
- Requires-Dist: aiohttp-session[aioredis] >=2.12 ; extra == 'redis'
26
- Provides-Extra: tests
27
- Requires-Dist: black ==24.8.0 ; extra == 'tests'
28
- Requires-Dist: pylint ==3.2.6 ; extra == 'tests'
29
- Requires-Dist: flake8 ; extra == 'tests'
30
- Requires-Dist: pytest-aiohttp ; extra == 'tests'
31
- Requires-Dist: pytest ; extra == 'tests'
32
- Requires-Dist: pytest-cov ; extra == 'tests'
33
- Requires-Dist: pytest-asyncio ; extra == 'tests'
34
- Requires-Dist: pytest-env ; extra == 'tests'
35
25
 
36
26
  # aiohttp_msal Python library
37
27
 
@@ -139,3 +129,12 @@ def main()
139
129
  # ...
140
130
  # use the Graphclient
141
131
  ```
132
+
133
+ ## Development
134
+
135
+ ```bash
136
+ uv sync --all-extras
137
+ uv tool install ruff
138
+ uv tool install codespell
139
+ uv tool install pyproject-fmt
140
+ ```
@@ -0,0 +1,10 @@
1
+ aiohttp_msal/__init__.py,sha256=867ca27f2272908ecd32f33ddcaac722c4e5aeb7c2554ae929d6d128be86b9bc,4027
2
+ aiohttp_msal/msal_async.py,sha256=8efdf9608c55e41f99ed66294a18303192e921739f49986cf90a015b07a50c55,9470
3
+ aiohttp_msal/redis_tools.py,sha256=ea40b0d3fcc341cbc872c25a3df1becc752f4f7bf392b37248d4d5e72d6d7241,6539
4
+ aiohttp_msal/routes.py,sha256=f305368d4f6a4a5a87e5fabd92f901647020af141d981f83bed402f969576e1d,8135
5
+ aiohttp_msal/settings.py,sha256=b6d6ea19bd97d6bec3b0bbca6f50250c32816195a32301cb1c8f926bef0afa52,1562
6
+ aiohttp_msal/settings_base.py,sha256=b516e3829851d6dbc70ab14271f394868140c949eafd79c90553b834d88f74d8,3150
7
+ aiohttp_msal/user_info.py,sha256=b4efaf03f9313ec787b1e5b136584673f3650df80919d7f538c09a1a6bc37fc4,1875
8
+ aiohttp_msal-1.0.0.dist-info/WHEEL,sha256=76443c98c0efcfdd1191eac5fa1d8223dba1c474dbd47676674a255e7ca48770,79
9
+ aiohttp_msal-1.0.0.dist-info/METADATA,sha256=dd4a69bc47da5e6c559fda12c41505fdc047c8edde36d498d25a93fc0b7b4ec8,4478
10
+ aiohttp_msal-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: uv 0.8.12
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -1,21 +0,0 @@
1
- The MIT License (MIT)
2
-
3
- Copyright (c) 2022-2024 kellerza
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
@@ -1,18 +0,0 @@
1
- aiohttp_msal/__init__.py,sha256=-8b6kR9wbqLoWmAVw4MsU5GYv5EK8IWoPFJJLr_UsHk,3850
2
- aiohttp_msal/msal_async.py,sha256=afvfh7gZrXk5KO7Umb9jAnQu4jdg_iVlgaTnxS3JgNM,8899
3
- aiohttp_msal/redis_tools.py,sha256=zgRACVxm2wPkbEHtA6VmArsd-QQlKn-crlq1XlFbjEY,5919
4
- aiohttp_msal/routes.py,sha256=gUkNJknrR0_9ohdgMejWCQXfcobPltuAl_3C4wthZAM,8198
5
- aiohttp_msal/settings.py,sha256=hWVJdtqcdAkqqN5I4GINJIZSFGhEuoBImM26NrhqY_M,1341
6
- aiohttp_msal/settings_base.py,sha256=m4tmurnq8xipVNAa-Dh4ii9Rsu6gg39F4aDJNHPLwiI,2919
7
- aiohttp_msal/user_info.py,sha256=fijBUbl5g1AVgrpOl-2ZY-eQCCWcu4YqcA0QaMQrcWw,1766
8
- tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- tests/test_init.py,sha256=EXq56_E2FUePXVFilAauHXoqItXe5Lvlpz8hBSUh6cU,1832
10
- tests/test_msal_async.py,sha256=7-G6dO3_qWb8zxqC6_SqMMubCBbEERZy513P7UM4vmw,365
11
- tests/test_redis_tools.py,sha256=uFpPSe6atbDVAuh1_OUtFgeZwyuLDspp42_EECJDSPg,1869
12
- tests/test_settings.py,sha256=z-qtUs1zl5Q9NEux051eebyPnArLZ_OfZu65FKz0N4Y,333
13
- aiohttp_msal-0.7.1.dist-info/LICENSE,sha256=BwqFEcF0Ij49hDZx4A_5CzsKnfU_twRjrm87JFwydFc,1080
14
- aiohttp_msal-0.7.1.dist-info/METADATA,sha256=mxa80vNNCVI7s7I5htk1OHqVBdfgy1TuPCm8_ydV8x4,4724
15
- aiohttp_msal-0.7.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
16
- aiohttp_msal-0.7.1.dist-info/top_level.txt,sha256=QPWOi5JtacVEdbaU5bJExc9o-cCT2Lufx0QhUpsv5_E,19
17
- aiohttp_msal-0.7.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
18
- aiohttp_msal-0.7.1.dist-info/RECORD,,
@@ -1,5 +0,0 @@
1
- Wheel-Version: 1.0
2
- Generator: setuptools (72.2.0)
3
- Root-Is-Purelib: true
4
- Tag: py3-none-any
5
-
@@ -1,2 +0,0 @@
1
- aiohttp_msal
2
- tests
@@ -1 +0,0 @@
1
-
tests/__init__.py DELETED
File without changes
tests/test_init.py DELETED
@@ -1,79 +0,0 @@
1
- """Init."""
2
-
3
- from unittest.mock import MagicMock, patch
4
-
5
- import pytest
6
-
7
- import aiohttp_msal.routes # noqa
8
- from aiohttp_msal import auth_ok, msal_session
9
-
10
-
11
- def a_yes(ses):
12
- return True
13
-
14
-
15
- def a_no(ses):
16
- return False
17
-
18
-
19
- @msal_session(a_yes, a_yes)
20
- async def t_2yes(request, ses):
21
- return True
22
-
23
-
24
- @msal_session(a_no, a_yes, at_least_one=True)
25
- async def t_1no1yes_one(request, ses):
26
- return True
27
-
28
-
29
- @msal_session(a_no, a_no, at_least_one=True)
30
- async def t_2no_one(request, ses):
31
- return True
32
-
33
-
34
- @patch("aiohttp_msal.get_session")
35
- async def test_include_any(get_session: MagicMock):
36
- get_session.return_value = {}
37
-
38
- assert await t_2yes({})
39
-
40
- with pytest.raises(Exception):
41
- await t_1no1yes_one
42
-
43
- with pytest.raises(Exception):
44
- await t_2no_one
45
-
46
-
47
- async def func(request, ses):
48
- return True
49
-
50
-
51
- @patch("aiohttp_msal.get_session")
52
- async def test_msal_session_auth(get_session: MagicMock):
53
- get_session.return_value = {}
54
-
55
- assert await msal_session(a_yes, a_yes)(func)({})
56
- assert await msal_session(a_yes, a_no, at_least_one=True)(func)({})
57
- assert await msal_session(a_no, a_yes, at_least_one=True)(func)({})
58
- assert await msal_session(a_no, a_no, a_no, a_yes, at_least_one=True)(func)({})
59
-
60
- with pytest.raises(Exception):
61
- await msal_session(a_yes, a_no)(func)({})
62
-
63
- with pytest.raises(Exception):
64
- await msal_session(a_yes, a_yes, a_no)(func)({})
65
-
66
- with pytest.raises(Exception):
67
- await msal_session(a_no, a_no, at_least_one=True)(func)({})
68
-
69
-
70
- @patch("aiohttp_msal.get_session")
71
- async def test_auth_ok(get_session: MagicMock):
72
- get_session.return_value = {"mail": "yes!"}
73
-
74
- assert await msal_session(a_yes)(func)({})
75
-
76
- get_session.return_value = {}
77
-
78
- with pytest.raises(Exception):
79
- assert await msal_session(a_yes, auth_ok)(func)({})
tests/test_msal_async.py DELETED
@@ -1,12 +0,0 @@
1
- """Test the AsyncMSAL class."""
2
-
3
- from aiohttp_msal.msal_async import AsyncMSAL, Session
4
-
5
-
6
- def test_ses():
7
- session = Session(None, new=True, data={"session": {"mail": "j@k", "name": "j"}})
8
- ses = AsyncMSAL(session)
9
- assert str(ses.name) == "j"
10
- assert str(ses.mail) == "j@k"
11
- assert str(ses.manager_mail) == ""
12
- assert str(ses.manager_name) == ""
tests/test_redis_tools.py DELETED
@@ -1,60 +0,0 @@
1
- """Test redis tools."""
2
-
3
- from json import dumps
4
- from typing import AsyncGenerator
5
- from unittest.mock import AsyncMock, MagicMock, Mock, call
6
-
7
- import pytest
8
-
9
- from aiohttp_msal.redis_tools import Redis, session_iter
10
-
11
-
12
- @pytest.fixture
13
- def redis() -> Redis:
14
- """Get a redis Mock instance."""
15
- testdata = {
16
- "a": dumps({"created": 1, "session": {"key": "a", "a": 1, "b": "2a"}}),
17
- "b": dumps({"created": 2, "session": {"key": "b", "a": 1, "b": "2b"}}),
18
- "c": dumps({"created": 3, "session": {"key": "c", "a": 5, "b": "6c"}}),
19
- }
20
-
21
- async def scan_iter(*, count: int, match: str) -> AsyncGenerator[str, None]:
22
- """Mock keys."""
23
- assert count == 100
24
- assert match == "a*"
25
- for key in testdata:
26
- yield key
27
-
28
- red = Mock()
29
- red.scan_iter = MagicMock(side_effect=scan_iter)
30
- red.get = AsyncMock(side_effect=list(testdata.values()))
31
- return red
32
-
33
-
34
- @pytest.mark.asyncio
35
- async def test_session_iter_fail(redis: Redis) -> None:
36
- """Test session iter."""
37
- match = {"a": 1}
38
- with pytest.raises(ValueError):
39
- async for _ in session_iter(redis, match=match, key_match="a*"):
40
- pass
41
-
42
- match = {"a": "1"}
43
- async for _ in session_iter(redis, match=match, key_match="a*"):
44
- assert False, "no match expected"
45
-
46
-
47
- @pytest.mark.asyncio
48
- async def test_session_iter(redis: Redis) -> None:
49
- """Test session iter."""
50
- match = {"b": "2"}
51
- expected = ["a", "b"]
52
- async for key, created, ses in session_iter(redis, match=match, key_match="a*"):
53
- assert expected.pop(0) == key
54
- assert key == ses["key"]
55
- assert created in (1, 2)
56
- assert key in ("a", "b")
57
-
58
- assert redis.scan_iter.call_args[1]["match"] == "a*"
59
- assert redis.scan_iter.call_args[1]["count"] == 100
60
- assert redis.scan_iter.call_args_list == [call(count=100, match="a*")]
tests/test_settings.py DELETED
@@ -1,11 +0,0 @@
1
- from aiohttp_msal.settings import ENV, Var
2
-
3
-
4
- def test_load():
5
- assert ENV.DOMAIN == "mydomain.com"
6
- assert isinstance(ENV.SP_APP_ID, Var)
7
- ENV.load("X_")
8
- assert ENV.SP_APP_ID == "i1"
9
- assert ENV.SP_APP_PW == "p1"
10
- ENV.load()
11
- assert ENV.to_dict() == {"SP_APP_ID": "i2", "SP_APP_PW": "p2", "SP_AUTHORITY": "a2"}