aiohttp-msal 1.0.5__tar.gz → 1.0.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: aiohttp-msal
3
- Version: 1.0.5
3
+ Version: 1.0.7
4
4
  Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
5
5
  Keywords: aiohttp,asyncio,msal,oauth
6
6
  Author: Johann Kellerman
@@ -4,7 +4,7 @@ requires = [ "uv-build" ] # >=0.5.15,<0.6
4
4
 
5
5
  [project]
6
6
  name = "aiohttp-msal"
7
- version = "1.0.5"
7
+ version = "1.0.7"
8
8
  description = "Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp"
9
9
  readme = "README.md"
10
10
  keywords = [ "aiohttp", "asyncio", "msal", "oauth" ]
@@ -1,11 +1,10 @@
1
1
  """aiohttp_msal."""
2
2
 
3
- import json
4
3
  import logging
5
4
  from collections.abc import Awaitable, Callable
6
5
  from functools import wraps
7
6
  from inspect import getfullargspec, iscoroutinefunction
8
- from typing import Any, TypeVar, TypeVarTuple, cast
7
+ from typing import TypeVar, TypeVarTuple, cast
9
8
 
10
9
  from aiohttp import ClientSession, web
11
10
  from aiohttp_session import get_session
@@ -15,7 +14,7 @@ from aiohttp_msal.msal_async import AsyncMSAL
15
14
  from aiohttp_msal.settings import ENV
16
15
  from aiohttp_msal.utils import retry
17
16
 
18
- _LOGGER = logging.getLogger(__name__)
17
+ _LOG = logging.getLogger(__name__)
19
18
 
20
19
  _T = TypeVar("_T")
21
20
  Ts = TypeVarTuple("Ts")
@@ -92,8 +91,6 @@ async def app_init_redis_session(
92
91
  app: web.Application,
93
92
  max_age: int = 3600 * 24 * 90,
94
93
  check_proxy_cb: Callable[[], Awaitable[None]] | None = None,
95
- encoder: Callable[[object], str] = json.dumps,
96
- decoder: Callable[[str], Any] = json.loads,
97
94
  ) -> None:
98
95
  """Init an aiohttp_session with Redis storage helper.
99
96
 
@@ -107,7 +104,7 @@ async def app_init_redis_session(
107
104
  else:
108
105
  await check_proxy()
109
106
 
110
- _LOGGER.info("Connect to Redis %s", ENV.REDIS)
107
+ _LOG.info("Connect to Redis %s", ENV.REDIS)
111
108
  try:
112
109
  ENV.database = from_url(ENV.REDIS)
113
110
  # , encoding="utf-8", decode_responses=True
@@ -123,8 +120,8 @@ async def app_init_redis_session(
123
120
  secure=True,
124
121
  domain=ENV.DOMAIN,
125
122
  cookie_name=ENV.COOKIE_NAME,
126
- encoder=encoder,
127
- decoder=decoder,
123
+ encoder=ENV.dumps,
124
+ decoder=ENV.loads,
128
125
  )
129
126
  _setup(app, storage)
130
127
 
@@ -6,7 +6,6 @@ Once you have the OAuth tokens store in the session, you are free to make reques
6
6
  """
7
7
 
8
8
  import asyncio
9
- import json
10
9
  import logging
11
10
  from collections.abc import Callable
12
11
  from functools import cached_property, partialmethod
@@ -43,7 +42,7 @@ T = TypeVar("T")
43
42
 
44
43
  @attrs.define(slots=False)
45
44
  class AsyncMSAL:
46
- """AsycMSAL class.
45
+ """AsyncMSAL class.
47
46
 
48
47
  Authorization Code Flow Helper. Learn more about auth-code-flow at
49
48
  https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-auth-code-flow
@@ -220,7 +219,7 @@ class AsyncMSAL:
220
219
  elif method in [HTTP_POST, HTTP_PUT, HTTP_PATCH]:
221
220
  headers["Content-type"] = "application/json"
222
221
  if "data" in kwargs:
223
- kwargs["data"] = json.dumps(kwargs["data"]) # auto convert to json
222
+ kwargs["data"] = ENV.dumps(kwargs["data"]) # auto convert to json
224
223
 
225
224
  if not AsyncMSAL.client_session:
226
225
  AsyncMSAL.client_session = ClientSession(trust_env=True)
@@ -1,19 +1,18 @@
1
1
  """Redis tools for sessions."""
2
2
 
3
3
  import asyncio
4
- import json
5
4
  import logging
6
5
  import time
7
6
  from collections.abc import AsyncGenerator
8
7
  from contextlib import AsyncExitStack, asynccontextmanager
9
- from typing import Any
8
+ from typing import Any, TypeVar
10
9
 
11
10
  from redis.asyncio import Redis, from_url
12
11
 
13
12
  from aiohttp_msal.msal_async import AsyncMSAL
14
- from aiohttp_msal.settings import ENV as MENV
13
+ from aiohttp_msal.settings import ENV
15
14
 
16
- _LOGGER = logging.getLogger(__name__)
15
+ _LOG = logging.getLogger(__name__)
17
16
 
18
17
  SES_KEYS = ("mail", "name", "m_mail", "m_name")
19
18
 
@@ -21,22 +20,23 @@ SES_KEYS = ("mail", "name", "m_mail", "m_name")
21
20
  @asynccontextmanager
22
21
  async def get_redis() -> AsyncGenerator[Redis, None]:
23
22
  """Get a Redis connection."""
24
- if MENV.database:
25
- _LOGGER.debug("Using redis from environment")
26
- yield MENV.database
23
+ if ENV.database:
24
+ _LOG.debug("Using redis from environment")
25
+ yield ENV.database
27
26
  return
28
- _LOGGER.info("Connect to Redis %s", MENV.REDIS)
29
- redis = from_url(MENV.REDIS) # decode_responses=True not allowed aiohttp_session
30
- MENV.database = redis
27
+ _LOG.info("Connect to Redis %s", ENV.REDIS)
28
+ redis = from_url(ENV.REDIS) # decode_responses=True not allowed aiohttp_session
29
+ ENV.database = redis
31
30
  try:
32
31
  yield redis
33
32
  finally:
34
- MENV.database = None # type:ignore[assignment]
33
+ ENV.database = None # type:ignore[assignment]
35
34
  await redis.close()
36
35
 
37
36
 
38
37
  async def session_iter(
39
38
  redis: Redis,
39
+ /,
40
40
  *,
41
41
  match: dict[str, str] | None = None,
42
42
  key_match: str | None = None,
@@ -49,14 +49,14 @@ async def session_iter(
49
49
  if match and not all(isinstance(v, str) for v in match.values()):
50
50
  raise ValueError("match values must be strings")
51
51
  async for key in redis.scan_iter(
52
- count=100, match=key_match or f"{MENV.COOKIE_NAME}*"
52
+ count=100, match=key_match or f"{ENV.COOKIE_NAME}*"
53
53
  ):
54
54
  if not isinstance(key, str):
55
55
  key = key.decode()
56
56
  sval = await redis.get(key)
57
57
  created, ses = 0, {}
58
58
  try:
59
- val = json.loads(sval) # type: ignore[arg-type]
59
+ val = ENV.loads(sval) # type: ignore[arg-type]
60
60
  created = int(val["created"])
61
61
  ses = val["session"]
62
62
  except Exception:
@@ -74,7 +74,7 @@ async def session_iter(
74
74
 
75
75
 
76
76
  async def session_clean(
77
- redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
77
+ redis: Redis, /, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
78
78
  ) -> None:
79
79
  """Clear session entries older than max_age days."""
80
80
  rem, keep = 0, 0
@@ -89,30 +89,35 @@ async def session_clean(
89
89
  keep += 1
90
90
  finally:
91
91
  if rem:
92
- _LOGGER.info("Sessions removed: %s (%s total)", rem, keep)
92
+ _LOG.info("Sessions removed: %s (%s total)", rem, keep)
93
93
  else:
94
- _LOGGER.debug("No sessions removed (%s total)", keep)
94
+ _LOG.debug("No sessions removed (%s total)", keep)
95
95
 
96
96
 
97
- async def invalid_sessions(redis: Redis) -> None:
97
+ async def invalid_sessions(redis: Redis, /) -> None:
98
98
  """Find & clean invalid sessions."""
99
- async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
99
+ async for key in redis.scan_iter(count=100, match=f"{ENV.COOKIE_NAME}*"):
100
100
  if not isinstance(key, str):
101
101
  key = key.decode()
102
102
  sval = await redis.get(key)
103
103
  if sval is None:
104
104
  continue
105
105
  try:
106
- val: dict = json.loads(sval)
106
+ val: dict = ENV.loads(sval)
107
107
  assert isinstance(val["created"], int)
108
108
  assert isinstance(val["session"], dict)
109
109
  except Exception as err:
110
- _LOGGER.warning("Removing session %s: %s", key, err)
110
+ _LOG.warning("Removing session %s: %s", key, err)
111
111
  await redis.delete(key)
112
112
 
113
113
 
114
- def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
115
- """Create a AsyncMSAL session.
114
+ T = TypeVar("T", bound=AsyncMSAL)
115
+
116
+
117
+ def async_msal_factory(
118
+ cls: type[T], key: str, created: int, session: dict[str, Any], /
119
+ ) -> T:
120
+ """Create a AsyncMSAL session with a save_callback.
116
121
 
117
122
  When get_token refreshes the token retrieved from Redis, the save_cache callback
118
123
  will be responsible to update the cache in Redis.
@@ -121,7 +126,7 @@ def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
121
126
  async def async_save_cache(_: dict) -> None:
122
127
  """Save the token cache to Redis."""
123
128
  async with get_redis() as rd2:
124
- await rd2.set(key, json.dumps({"created": created, "session": session}))
129
+ await rd2.set(key, ENV.dumps({"created": created, "session": session}))
125
130
 
126
131
  def save_cache(*args: Any) -> None:
127
132
  """Save the token cache to Redis."""
@@ -130,12 +135,17 @@ def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
130
135
  except RuntimeError:
131
136
  asyncio.run(async_save_cache(*args))
132
137
 
133
- return AsyncMSAL(session, save_callback=save_cache)
138
+ return cls(session, save_callback=save_cache)
134
139
 
135
140
 
136
141
  async def get_session(
137
- email: str, *, redis: Redis | None = None, scope: str = ""
138
- ) -> AsyncMSAL:
142
+ cls: type[T],
143
+ email: str,
144
+ /,
145
+ *,
146
+ redis: Redis | None = None,
147
+ scope: str = "",
148
+ ) -> T:
139
149
  """Get a session from Redis."""
140
150
  cnt = 0
141
151
  async with AsyncExitStack() as stack:
@@ -143,34 +153,34 @@ async def get_session(
143
153
  redis = await stack.enter_async_context(get_redis())
144
154
  async for key, created, session in session_iter(redis, match={"mail": email}):
145
155
  cnt += 1
146
- if scope and scope not in str(session.get("token_cache")).lower():
156
+ if scope and scope not in str(session.get(cls.token_cache_key)).lower():
147
157
  continue
148
- return _session_factory(key, created, session)
158
+ return async_msal_factory(cls, key, created, session)
149
159
  msg = f"Session for {email}"
150
160
  if not scope:
151
161
  raise ValueError(f"{msg} not found")
152
162
  raise ValueError(f"{msg} with scope {scope} not found ({cnt} checked)")
153
163
 
154
164
 
155
- async def redis_get_json(key: str) -> list | dict | None:
165
+ async def redis_get_json(key: str) -> list[Any] | dict[str, Any] | None:
156
166
  """Get a key from redis."""
157
- res = await MENV.database.get(key)
167
+ res = await ENV.database.get(key)
158
168
  if isinstance(res, str | bytes | bytearray):
159
- return json.loads(res)
169
+ return ENV.loads(res)
160
170
  if res is not None:
161
- _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
171
+ _LOG.warning("Unexpected type for %s: %s", key, type(res))
162
172
  return None
163
173
 
164
174
 
165
175
  async def redis_get(key: str) -> str | None:
166
176
  """Get a key from redis."""
167
- res = await MENV.database.get(key)
177
+ res = await ENV.database.get(key)
168
178
  if isinstance(res, str):
169
179
  return res
170
180
  if isinstance(res, bytes | bytearray):
171
181
  return res.decode()
172
182
  if res is not None:
173
- _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
183
+ _LOG.warning("Unexpected type for %s: %s", key, type(res))
174
184
  return None
175
185
 
176
186
 
@@ -178,22 +188,22 @@ async def redis_set_set(key: str, new_set: set[str]) -> None:
178
188
  """Set the value of a set in redis."""
179
189
  cur_set = set(
180
190
  s if isinstance(s, str) else s.decode()
181
- for s in await MENV.database.smembers(key)
191
+ for s in await ENV.database.smembers(key)
182
192
  )
183
193
  dif = list(cur_set - new_set)
184
194
  if dif:
185
- _LOGGER.warning("%s: removing %s", key, dif)
186
- await MENV.database.srem(key, *dif)
195
+ _LOG.warning("%s: removing %s", key, dif)
196
+ await ENV.database.srem(key, *dif)
187
197
 
188
198
  dif = list(new_set - cur_set)
189
199
  if dif:
190
- _LOGGER.info("%s: adding %s", key, dif)
191
- await MENV.database.sadd(key, *dif)
200
+ _LOG.info("%s: adding %s", key, dif)
201
+ await ENV.database.sadd(key, *dif)
192
202
 
193
203
 
194
- async def redis_scan(match_str: str) -> list[str]:
204
+ async def redis_scan_keys(match_str: str) -> list[str]:
195
205
  """Return a list of matching keys."""
196
206
  return [
197
207
  s if isinstance(s, str) else s.decode()
198
- async for s in MENV.database.scan_iter(match=match_str)
208
+ async for s in ENV.database.scan_iter(match=match_str)
199
209
  ]
@@ -1,5 +1,6 @@
1
1
  """Settings."""
2
2
 
3
+ import json
3
4
  from collections.abc import Awaitable, Callable
4
5
  from typing import TYPE_CHECKING, Any
5
6
 
@@ -39,8 +40,11 @@ class MSALSettings(SettingsBase):
39
40
 
40
41
  REDIS: str = "redis://redis1:6379"
41
42
  """OPTIONAL: Redis database connection used by app_init_redis_session()."""
42
- database: "Redis" = attrs.field(init=False)
43
+ database: "Redis" = attrs.field(init=False, default=None)
43
44
  """Store the Redis connection when using app_init_redis_session()."""
44
45
 
46
+ dumps: Callable[[Any], str] = attrs.field(default=json.dumps)
47
+ loads: Callable[[str | bytes | bytearray], Any] = attrs.field(default=json.loads)
48
+
45
49
 
46
50
  ENV = MSALSettings()
File without changes