aiohttp-msal 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aiohttp_msal/__init__.py CHANGED
@@ -13,7 +13,7 @@ from .settings import ENV
 
 _LOGGER = logging.getLogger(__name__)
 
-VERSION = "0.6.3"
+VERSION = "0.6.5"
 
 
 def msal_session(*args: Callable[[AsyncMSAL], Union[Any, Awaitable[Any]]]) -> Callable:

aiohttp_msal/msal_async.py CHANGED
@@ -23,7 +23,7 @@ HTTP_PATCH = "patch"
 HTTP_DELETE = "delete"
 HTTP_ALLOWED = [HTTP_GET, HTTP_POST, HTTP_PUT, HTTP_PATCH, HTTP_DELETE]
 
-MY_SCOPE = ["User.Read", "User.Read.All"]
+DEFAULT_SCOPES = ["User.Read", "User.Read.All"]
 
 
 def async_wrap(func: Callable) -> Callable:
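
MY_SCOPE is renamed to DEFAULT_SCOPES, which is a breaking change for code that imported the old constant. A hedged migration sketch, assuming the constant is importable from aiohttp_msal.msal_async as the hunk context suggests:

```python
# Hedged migration sketch: support both 0.6.3 and 0.6.5 while upgrading.
# Assumes the constant lives in aiohttp_msal.msal_async, as the hunk
# context above suggests.
try:
    from aiohttp_msal.msal_async import DEFAULT_SCOPES as SCOPES  # 0.6.5+
except ImportError:  # older release
    from aiohttp_msal.msal_async import MY_SCOPE as SCOPES  # <= 0.6.3

print(SCOPES)  # ["User.Read", "User.Read.All"] by default
```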
@@ -71,11 +71,12 @@ class AsyncMSAL:
     https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.initiate_auth_code_flow
 
     The caller is expected to:
-    1.somehow store this content, typically inside the current session of the server,
-    2.guide the end user (i.e. resource owner) to visit that auth_uri,
-    typically with a redirect
-    3.and then relay this dict and subsequent auth response to
-    acquire_token_by_auth_code_flow().
+    1. somehow store this content, typically inside the current session of the
+       server,
+    2. guide the end user (i.e. resource owner) to visit that auth_uri,
+       typically with a redirect
+    3. and then relay this dict and subsequent auth response to
+       acquire_token_by_auth_code_flow().
 
     [1. and part of 3.] is stored by this class in the aiohttp_session
 
@@ -147,12 +148,14 @@ class AsyncMSAL:
         if hasattr(self, "save_token_cache"):
             self.save_token_cache(self.token_cache)
 
-    def build_auth_code_flow(self, redirect_uri: str) -> str:
+    def build_auth_code_flow(
+        self, redirect_uri: str, scopes: Optional[list[str]] = None
+    ) -> str:
         """First step - Start the flow."""
         self.session[TOKEN_CACHE] = None  # type: ignore
         self.session[USER_EMAIL] = None  # type: ignore
         self.session[FLOW_CACHE] = res = self.app.initiate_auth_code_flow(
-            MY_SCOPE,
+            scopes or DEFAULT_SCOPES,
             redirect_uri=redirect_uri,
             response_mode="form_post"
             # max_age=1209600,
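
build_auth_code_flow now accepts an optional scopes list and falls back to DEFAULT_SCOPES. A minimal sketch of a login handler that passes custom scopes, assuming the returned string is the authorization URL to redirect the user to (suggested by the `-> str` annotation and the flow docstring above); the route and redirect URI are illustrative:

```python
# Hedged sketch: an aiohttp login handler requesting extra scopes.
# The redirect URI is illustrative; scopes= may be omitted to fall back
# to DEFAULT_SCOPES.
from aiohttp import web
from aiohttp_session import get_session as get_http_session  # aiohttp_session, not redis_tools
from aiohttp_msal.msal_async import AsyncMSAL

async def login(request: web.Request) -> web.Response:
    ses = AsyncMSAL(await get_http_session(request))
    auth_uri = ses.build_auth_code_flow(
        redirect_uri="https://example.invalid/user/authorized",
        scopes=["User.Read", "Mail.Read"],
    )
    return web.HTTPFound(auth_uri)
```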
@@ -182,11 +185,13 @@ class AsyncMSAL:
             None, self.acquire_token_by_auth_code_flow, auth_response
         )
 
-    def get_token(self) -> Optional[dict[str, Any]]:
+    def get_token(self, scopes: Optional[list[str]] = None) -> Optional[dict[str, Any]]:
         """Acquire a token based on username."""
         accounts = self.app.get_accounts()
         if accounts:
-            result = self.app.acquire_token_silent(scopes=MY_SCOPE, account=accounts[0])
+            result = self.app.acquire_token_silent(
+                scopes=scopes or DEFAULT_SCOPES, account=accounts[0]
+            )
             self._save_token_cache()
             return result
         return None
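
get_token gains the same optional scopes argument for silent token acquisition. A small sketch, assuming MSAL's usual result shape (a dict containing "access_token" on success) and that None means no cached account or token was available:

```python
# Hedged sketch: build an Authorization header from a silently acquired
# token for a non-default scope. The helper name and scope are illustrative.
from typing import Optional
from aiohttp_msal.msal_async import AsyncMSAL

def bearer_header(ases: AsyncMSAL, scope: str = "Mail.Read") -> Optional[dict[str, str]]:
    token = ases.get_token(scopes=[scope])
    if not token or "access_token" not in token:
        return None  # no cached account, or silent acquisition failed
    return {"Authorization": f"Bearer {token['access_token']}"}
```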

aiohttp_msal/redis_tools.py CHANGED
@@ -3,6 +3,7 @@ import asyncio
 import json
 import logging
 import time
+from contextlib import AsyncExitStack, asynccontextmanager
 from typing import Any, AsyncGenerator, Optional
 
 from redis.asyncio import Redis, from_url
@@ -15,56 +16,81 @@ _LOGGER = logging.getLogger(__name__)
 SES_KEYS = ("mail", "name", "m_mail", "m_name")
 
 
-def get_redis() -> Redis:
+@asynccontextmanager
+async def get_redis() -> AsyncGenerator[Redis, None]:
     """Get a Redis connection."""
+    if ENV.database:
+        _LOGGER.debug("Using redis from environment")
+        yield ENV.database
+        return
     _LOGGER.info("Connect to Redis %s", ENV.REDIS)
-    ENV.database = from_url(ENV.REDIS)  # pylint: disable=no-member
-    return ENV.database
+    redis = from_url(ENV.REDIS)
+    try:
+        yield redis
+    finally:
+        await redis.close()
+
+
+async def session_iter(
+    redis: Redis,
+    *,
+    match: Optional[dict[str, str]] = None,
+    key_match: Optional[str] = None,
+) -> AsyncGenerator[tuple[str, int, dict[str, Any]], None]:
+    """Iterate over the Redis keys to find a specific session.
+
+    match: Filter based on session content (i.e. mail/name)
+    key_match: Filter the Redis keys. Defaults to ENV.cookie_name
+    """
+    async for key in redis.scan_iter(
+        count=100, match=key_match or f"{ENV.COOKIE_NAME}*"
+    ):
+        sval = await redis.get(key)
+        created, ses = 0, {}
+        try:
+            val = json.loads(sval)  # type: ignore
+            created = int(val["created"])
+            ses = val["session"]
+        except Exception:  # pylint: disable=broad-except
+            pass
+        if match:
+            # Ensure we match all the supplied terms
+            if not all(k in ses and v in ses[k] for k, v in match.items()):
+                continue
+        yield key, created, ses
 
 
-async def iter_redis(
-    redis: Redis, *, clean: bool = False, match: Optional[dict[str, str]] = None
-) -> AsyncGenerator[tuple[str, str, dict], None]:
-    """Iterate over the Redis keys to find a specific session."""
-    async for key in redis.scan_iter(count=100, match=f"{ENV.COOKIE_NAME}*"):
-        sval = await redis.get(key)
-        if not isinstance(sval, (str, bytes, bytearray)):
-            if clean:
-                await redis.delete(key)
-            continue
-        val = json.loads(sval)
-        ses = val.get("session") or {}
-        created = val.get("created")
-        if clean and not ses or not created:
-            await redis.delete(key)
-            continue
-        if match and not all(v in ses[k] for k, v in match.items()):
-            continue
-        yield key, created or "0", ses
-
-
-async def clean_redis(redis: Redis, max_age: int = 90) -> None:
+async def session_clean(
+    redis: Redis, *, max_age: int = 90, expected_keys: Optional[dict] = None
+) -> None:
     """Clear session entries older than max_age days."""
+    rem, keep = 0, 0
     expire = int(time.time() - max_age * 24 * 60 * 60)
-    async for key, created, ses in iter_redis(redis, clean=True):
-        for key in SES_KEYS:
-            if not ses.get(key):
+    try:
+        async for key, created, ses in session_iter(redis):
+            all_keys = all(sk in ses for sk in (expected_keys or SES_KEYS))
+            if created < expire or not all_keys:
+                rem += 1
                 await redis.delete(key)
-                continue
-        if int(created) < expire:
-            await redis.delete(key)
+            else:
+                keep += 1
+    finally:
+        if rem:
+            _LOGGER.info("Sessions removed: %s (%s total)", rem, keep)
+        else:
+            _LOGGER.debug("No sessions removed (%s total)", keep)
 
 
 def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
-    """Create a session with a save callback."""
+    """Create a AsyncMSAL session.
+
+    When get_token refreshes the token retrieved from Redis, the save_cache callback
+    will be responsible to update the cache in Redis."""
 
     async def async_save_cache(_: dict) -> None:
         """Save the token cache to Redis."""
-        rd2 = get_redis()
-        try:
+        async with get_redis() as rd2:
             await rd2.set(key, json.dumps({"created": created, "session": session}))
-        finally:
-            await rd2.close()
 
     def save_cache(*args: Any) -> None:
         """Save the token cache to Redis."""
@@ -76,8 +102,11 @@ def _session_factory(key: str, created: str, session: dict) -> AsyncMSAL:
     return AsyncMSAL(session, save_cache=save_cache)
 
 
-async def get_session(red: Redis, email: str) -> AsyncMSAL:
+async def get_session(email: str, *, redis: Optional[Redis] = None) -> AsyncMSAL:
     """Get a session from Redis."""
-    async for key, created, session in iter_redis(red, match={"mail": email}):
-        return _session_factory(key, created, session)
+    async with AsyncExitStack() as stack:
+        if redis is None:
+            redis = await stack.enter_async_context(get_redis())
+        async for key, created, session in session_iter(redis, match={"mail": email}):
+            return _session_factory(key, str(created), session)
     raise ValueError(f"Session for {email} not found")
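
get_session now takes the e-mail first and opens its own Redis connection via get_redis() when none is supplied. A short sketch of both call styles; the e-mail address is illustrative, and a ValueError is raised when no matching session is stored:

```python
# Hedged sketch: the two call styles of the 0.6.5 get_session().
import asyncio
from aiohttp_msal.redis_tools import get_redis, get_session

async def demo() -> None:
    # Let get_session manage its own connection...
    ases = await get_session("user@example.com")
    # ...or reuse an existing one via the keyword-only argument.
    async with get_redis() as redis:
        ases = await get_session("user@example.com", redis=redis)
    print(ases.get_token() is not None)

asyncio.run(demo())
```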

aiohttp_msal-0.6.3.dist-info/METADATA → aiohttp_msal-0.6.5.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: aiohttp-msal
-Version: 0.6.3
+Version: 0.6.5
 Summary: Helper Library to use the Microsoft Authentication Library (MSAL) with aiohttp
 Home-page: https://github.com/kellerza/aiohttp_msal
 Author: Johann Kellerman
@@ -125,23 +125,13 @@ async def user_authorized(request: web.Request) -> web.Response:
 
 ```python
 from aiohttp_msal import ENV, AsyncMSAL
-from aiohttp_msal.redis_tools import clean_redis, get_redis, get_session
-
-async def get_async_msal(email: str) -> AsyncMSAL:
-    """Clean redis and get a session."""
-    red = get_redis()
-    try:
-        return await get_session(red, email)
-    finally:
-        await red.close()
-
+from aiohttp_msal.redis_tools import get_session
 
 def main()
     # Uses the redis.asyncio driver to retrieve the current token
     # Will update the token_cache if a RefreshToken was used
-    ases = asyncio.run(get_async_msal(MYEMAIL))
+    ases = asyncio.run(get_session(MYEMAIL))
     client = GraphClient(ases.get_token)
     # ...
     # use the Graphclient
-    # ...
 ```

aiohttp_msal-0.6.3.dist-info/RECORD → aiohttp_msal-0.6.5.dist-info/RECORD
@@ -1,6 +1,6 @@
-aiohttp_msal/__init__.py,sha256=NfHiiJnjJer0z96e3aM5xQnJ2ICSeBmKu79JBHL7alo,3001
-aiohttp_msal/msal_async.py,sha256=lSwTK2utBVjhQQ921aoq34hNa0z-AJsQleSHXxh9STk,9779
-aiohttp_msal/redis_tools.py,sha256=I9tjkRSSAQi3S1TajsZX0fkvNgIYyOwgNGkViP1vjRA,2794
+aiohttp_msal/__init__.py,sha256=hz7_nNDPT3bWxWu78vjAAWr9i60eRX0He34RRZ6DIz0,3001
+aiohttp_msal/msal_async.py,sha256=Z810J2OHn7H4EevfQ7XB5L7Rks8iB4RsdFbFv5wPb3k,9952
+aiohttp_msal/redis_tools.py,sha256=eEYGJTCWtpyBvmI7IAs2jInypsyzbQP_RAVQnAyRgtE,3737
 aiohttp_msal/routes.py,sha256=c-w5wHaLAYGEqZvfZ8PnzzRh60asLqdUa30lvSANdYM,8319
 aiohttp_msal/settings.py,sha256=ZZn7D6QmIyQSvuqCAoTacKRXYfopqK4P74eVdPCw-uI,1231
 aiohttp_msal/settings_base.py,sha256=pmVmzTtaGEgRh-AMGy0HdhF1JvoZhZp42G3PL_ILHLw,2892
@@ -9,9 +9,9 @@ tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tests/test_init.py,sha256=sHmt7yNDlcu5JHrpM_gim0NeLce0NwUAMM3HAdGoo58,75
 tests/test_msal_async.py,sha256=31MCoAbUiyUhc4SkebUKpjLDHozEBko-QgEBSHjfSoM,332
 tests/test_settings.py,sha256=z-qtUs1zl5Q9NEux051eebyPnArLZ_OfZu65FKz0N4Y,333
-aiohttp_msal-0.6.3.dist-info/LICENSE,sha256=H1aGfkSfZFwK3q4INn9mUldOJGZy-ZXu5-65K9Glunw,1080
-aiohttp_msal-0.6.3.dist-info/METADATA,sha256=YWhtkOjRGbnHriO9MUEZQeYG6AKV1GBqnnN1zzdAHRY,4811
-aiohttp_msal-0.6.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-aiohttp_msal-0.6.3.dist-info/top_level.txt,sha256=QPWOi5JtacVEdbaU5bJExc9o-cCT2Lufx0QhUpsv5_E,19
-aiohttp_msal-0.6.3.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-aiohttp_msal-0.6.3.dist-info/RECORD,,
+aiohttp_msal-0.6.5.dist-info/LICENSE,sha256=H1aGfkSfZFwK3q4INn9mUldOJGZy-ZXu5-65K9Glunw,1080
+aiohttp_msal-0.6.5.dist-info/METADATA,sha256=xXoce4L-2TzT6OufblyHIsHoWYIrrsa3Qarxfi3KOT0,4565
+aiohttp_msal-0.6.5.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+aiohttp_msal-0.6.5.dist-info/top_level.txt,sha256=QPWOi5JtacVEdbaU5bJExc9o-cCT2Lufx0QhUpsv5_E,19
+aiohttp_msal-0.6.5.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+aiohttp_msal-0.6.5.dist-info/RECORD,,