aiohttp-msal 1.0.5__tar.gz → 1.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: aiohttp-msal
3
- Version: 1.0.5
3
+ Version: 1.0.6
4
4
  Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
5
5
  Keywords: aiohttp,asyncio,msal,oauth
6
6
  Author: Johann Kellerman
@@ -4,7 +4,7 @@ requires = [ "uv-build" ] # >=0.5.15,<0.6
4
4
 
5
5
  [project]
6
6
  name = "aiohttp-msal"
7
- version = "1.0.5"
7
+ version = "1.0.6"
8
8
  description = "Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp"
9
9
  readme = "README.md"
10
10
  keywords = [ "aiohttp", "asyncio", "msal", "oauth" ]
@@ -15,7 +15,7 @@ from aiohttp_msal.msal_async import AsyncMSAL
15
15
  from aiohttp_msal.settings import ENV
16
16
  from aiohttp_msal.utils import retry
17
17
 
18
- _LOGGER = logging.getLogger(__name__)
18
+ _LOG = logging.getLogger(__name__)
19
19
 
20
20
  _T = TypeVar("_T")
21
21
  Ts = TypeVarTuple("Ts")
@@ -107,7 +107,7 @@ async def app_init_redis_session(
107
107
  else:
108
108
  await check_proxy()
109
109
 
110
- _LOGGER.info("Connect to Redis %s", ENV.REDIS)
110
+ _LOG.info("Connect to Redis %s", ENV.REDIS)
111
111
  try:
112
112
  ENV.database = from_url(ENV.REDIS)
113
113
  # , encoding="utf-8", decode_responses=True
@@ -6,14 +6,14 @@ import logging
6
6
  import time
7
7
  from collections.abc import AsyncGenerator
8
8
  from contextlib import AsyncExitStack, asynccontextmanager
9
- from typing import Any
9
+ from typing import Any, TypeVar
10
10
 
11
11
  from redis.asyncio import Redis, from_url
12
12
 
13
13
  from aiohttp_msal.msal_async import AsyncMSAL
14
14
  from aiohttp_msal.settings import ENV as MENV
15
15
 
16
- _LOGGER = logging.getLogger(__name__)
16
+ _LOG = logging.getLogger(__name__)
17
17
 
18
18
  SES_KEYS = ("mail", "name", "m_mail", "m_name")
19
19
 
@@ -22,10 +22,10 @@ SES_KEYS = ("mail", "name", "m_mail", "m_name")
22
22
  async def get_redis() -> AsyncGenerator[Redis, None]:
23
23
  """Get a Redis connection."""
24
24
  if MENV.database:
25
- _LOGGER.debug("Using redis from environment")
25
+ _LOG.debug("Using redis from environment")
26
26
  yield MENV.database
27
27
  return
28
- _LOGGER.info("Connect to Redis %s", MENV.REDIS)
28
+ _LOG.info("Connect to Redis %s", MENV.REDIS)
29
29
  redis = from_url(MENV.REDIS) # decode_responses=True not allowed aiohttp_session
30
30
  MENV.database = redis
31
31
  try:
@@ -37,6 +37,7 @@ async def get_redis() -> AsyncGenerator[Redis, None]:
37
37
 
38
38
  async def session_iter(
39
39
  redis: Redis,
40
+ /,
40
41
  *,
41
42
  match: dict[str, str] | None = None,
42
43
  key_match: str | None = None,
@@ -74,7 +75,7 @@ async def session_iter(
74
75
 
75
76
 
76
77
  async def session_clean(
77
- redis: Redis, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
78
+ redis: Redis, /, *, max_age: int = 90, expected_keys: dict[str, Any] | None = None
78
79
  ) -> None:
79
80
  """Clear session entries older than max_age days."""
80
81
  rem, keep = 0, 0
@@ -89,12 +90,12 @@ async def session_clean(
89
90
  keep += 1
90
91
  finally:
91
92
  if rem:
92
- _LOGGER.info("Sessions removed: %s (%s total)", rem, keep)
93
+ _LOG.info("Sessions removed: %s (%s total)", rem, keep)
93
94
  else:
94
- _LOGGER.debug("No sessions removed (%s total)", keep)
95
+ _LOG.debug("No sessions removed (%s total)", keep)
95
96
 
96
97
 
97
- async def invalid_sessions(redis: Redis) -> None:
98
+ async def invalid_sessions(redis: Redis, /) -> None:
98
99
  """Find & clean invalid sessions."""
99
100
  async for key in redis.scan_iter(count=100, match=f"{MENV.COOKIE_NAME}*"):
100
101
  if not isinstance(key, str):
@@ -107,12 +108,17 @@ async def invalid_sessions(redis: Redis) -> None:
107
108
  assert isinstance(val["created"], int)
108
109
  assert isinstance(val["session"], dict)
109
110
  except Exception as err:
110
- _LOGGER.warning("Removing session %s: %s", key, err)
111
+ _LOG.warning("Removing session %s: %s", key, err)
111
112
  await redis.delete(key)
112
113
 
113
114
 
114
- def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
115
- """Create a AsyncMSAL session.
115
+ T = TypeVar("T", bound=AsyncMSAL)
116
+
117
+
118
+ def async_msal_factory(
119
+ cls: type[T], key: str, created: int, session: dict[str, Any], /
120
+ ) -> T:
121
+ """Create a AsyncMSAL session with a save_callback.
116
122
 
117
123
  When get_token refreshes the token retrieved from Redis, the save_cache callback
118
124
  will be responsible to update the cache in Redis.
@@ -130,12 +136,17 @@ def _session_factory(key: str, created: int, session: dict) -> AsyncMSAL:
130
136
  except RuntimeError:
131
137
  asyncio.run(async_save_cache(*args))
132
138
 
133
- return AsyncMSAL(session, save_callback=save_cache)
139
+ return cls(session, save_callback=save_cache)
134
140
 
135
141
 
136
142
  async def get_session(
137
- email: str, *, redis: Redis | None = None, scope: str = ""
138
- ) -> AsyncMSAL:
143
+ cls: type[T],
144
+ email: str,
145
+ /,
146
+ *,
147
+ redis: Redis | None = None,
148
+ scope: str = "",
149
+ ) -> T:
139
150
  """Get a session from Redis."""
140
151
  cnt = 0
141
152
  async with AsyncExitStack() as stack:
@@ -143,22 +154,22 @@ async def get_session(
143
154
  redis = await stack.enter_async_context(get_redis())
144
155
  async for key, created, session in session_iter(redis, match={"mail": email}):
145
156
  cnt += 1
146
- if scope and scope not in str(session.get("token_cache")).lower():
157
+ if scope and scope not in str(session.get(cls.token_cache_key)).lower():
147
158
  continue
148
- return _session_factory(key, created, session)
159
+ return async_msal_factory(cls, key, created, session)
149
160
  msg = f"Session for {email}"
150
161
  if not scope:
151
162
  raise ValueError(f"{msg} not found")
152
163
  raise ValueError(f"{msg} with scope {scope} not found ({cnt} checked)")
153
164
 
154
165
 
155
- async def redis_get_json(key: str) -> list | dict | None:
166
+ async def redis_get_json(key: str) -> list[str] | dict[str, Any] | None:
156
167
  """Get a key from redis."""
157
168
  res = await MENV.database.get(key)
158
169
  if isinstance(res, str | bytes | bytearray):
159
170
  return json.loads(res)
160
171
  if res is not None:
161
- _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
172
+ _LOG.warning("Unexpected type for %s: %s", key, type(res))
162
173
  return None
163
174
 
164
175
 
@@ -170,7 +181,7 @@ async def redis_get(key: str) -> str | None:
170
181
  if isinstance(res, bytes | bytearray):
171
182
  return res.decode()
172
183
  if res is not None:
173
- _LOGGER.warning("Unexpected type for %s: %s", key, type(res))
184
+ _LOG.warning("Unexpected type for %s: %s", key, type(res))
174
185
  return None
175
186
 
176
187
 
@@ -182,16 +193,16 @@ async def redis_set_set(key: str, new_set: set[str]) -> None:
182
193
  )
183
194
  dif = list(cur_set - new_set)
184
195
  if dif:
185
- _LOGGER.warning("%s: removing %s", key, dif)
196
+ _LOG.warning("%s: removing %s", key, dif)
186
197
  await MENV.database.srem(key, *dif)
187
198
 
188
199
  dif = list(new_set - cur_set)
189
200
  if dif:
190
- _LOGGER.info("%s: adding %s", key, dif)
201
+ _LOG.info("%s: adding %s", key, dif)
191
202
  await MENV.database.sadd(key, *dif)
192
203
 
193
204
 
194
- async def redis_scan(match_str: str) -> list[str]:
205
+ async def redis_scan_keys(match_str: str) -> list[str]:
195
206
  """Return a list of matching keys."""
196
207
  return [
197
208
  s if isinstance(s, str) else s.decode()
File without changes