ff-ltitoolkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. ff_ltitoolkit-0.1.0.dist-info/METADATA +98 -0
  2. ff_ltitoolkit-0.1.0.dist-info/RECORD +94 -0
  3. ff_ltitoolkit-0.1.0.dist-info/WHEEL +4 -0
  4. ff_ltitoolkit-0.1.0.dist-info/licenses/LICENSE +21 -0
  5. ltitoolkit/__init__.py +20 -0
  6. ltitoolkit/adapters/__init__.py +11 -0
  7. ltitoolkit/adapters/brightspace/__init__.py +35 -0
  8. ltitoolkit/adapters/brightspace/client.py +176 -0
  9. ltitoolkit/adapters/canvas/__init__.py +27 -0
  10. ltitoolkit/adapters/canvas/client.py +142 -0
  11. ltitoolkit/advantage/__init__.py +9 -0
  12. ltitoolkit/advantage/service.py +96 -0
  13. ltitoolkit/core/__init__.py +19 -0
  14. ltitoolkit/core/actions.py +6 -0
  15. ltitoolkit/core/assignments_grades.py +300 -0
  16. ltitoolkit/core/contrib/__init__.py +0 -0
  17. ltitoolkit/core/contrib/django/__init__.py +5 -0
  18. ltitoolkit/core/contrib/django/cookie.py +56 -0
  19. ltitoolkit/core/contrib/django/launch_data_storage/__init__.py +0 -0
  20. ltitoolkit/core/contrib/django/launch_data_storage/cache.py +10 -0
  21. ltitoolkit/core/contrib/django/lti1p3_tool_config/__init__.py +139 -0
  22. ltitoolkit/core/contrib/django/lti1p3_tool_config/admin.py +48 -0
  23. ltitoolkit/core/contrib/django/lti1p3_tool_config/apps.py +6 -0
  24. ltitoolkit/core/contrib/django/lti1p3_tool_config/migrations/0001_initial.py +168 -0
  25. ltitoolkit/core/contrib/django/lti1p3_tool_config/migrations/__init__.py +0 -0
  26. ltitoolkit/core/contrib/django/lti1p3_tool_config/models.py +185 -0
  27. ltitoolkit/core/contrib/django/message_launch.py +39 -0
  28. ltitoolkit/core/contrib/django/oidc_login.py +41 -0
  29. ltitoolkit/core/contrib/django/redirect.py +34 -0
  30. ltitoolkit/core/contrib/django/request.py +32 -0
  31. ltitoolkit/core/contrib/django/session.py +5 -0
  32. ltitoolkit/core/contrib/flask/__init__.py +7 -0
  33. ltitoolkit/core/contrib/flask/cookie.py +34 -0
  34. ltitoolkit/core/contrib/flask/launch_data_storage/__init__.py +0 -0
  35. ltitoolkit/core/contrib/flask/launch_data_storage/cache.py +9 -0
  36. ltitoolkit/core/contrib/flask/message_launch.py +32 -0
  37. ltitoolkit/core/contrib/flask/oidc_login.py +31 -0
  38. ltitoolkit/core/contrib/flask/redirect.py +34 -0
  39. ltitoolkit/core/contrib/flask/request.py +40 -0
  40. ltitoolkit/core/contrib/flask/session.py +5 -0
  41. ltitoolkit/core/contrib/py.typed +0 -0
  42. ltitoolkit/core/cookie.py +17 -0
  43. ltitoolkit/core/cookies_allowed_check.py +151 -0
  44. ltitoolkit/core/course_groups.py +115 -0
  45. ltitoolkit/core/deep_link.py +100 -0
  46. ltitoolkit/core/deep_link_resource.py +96 -0
  47. ltitoolkit/core/deployment.py +13 -0
  48. ltitoolkit/core/exception.py +16 -0
  49. ltitoolkit/core/grade.py +143 -0
  50. ltitoolkit/core/launch_data_storage/__init__.py +0 -0
  51. ltitoolkit/core/launch_data_storage/base.py +75 -0
  52. ltitoolkit/core/launch_data_storage/cache.py +43 -0
  53. ltitoolkit/core/launch_data_storage/session.py +29 -0
  54. ltitoolkit/core/lineitem.py +205 -0
  55. ltitoolkit/core/message_launch.py +828 -0
  56. ltitoolkit/core/message_validators/__init__.py +13 -0
  57. ltitoolkit/core/message_validators/abstract.py +25 -0
  58. ltitoolkit/core/message_validators/deep_link.py +34 -0
  59. ltitoolkit/core/message_validators/privacy_launch.py +40 -0
  60. ltitoolkit/core/message_validators/resource_message.py +21 -0
  61. ltitoolkit/core/message_validators/submission_review.py +45 -0
  62. ltitoolkit/core/names_roles.py +97 -0
  63. ltitoolkit/core/oidc_login.py +275 -0
  64. ltitoolkit/core/py.typed +0 -0
  65. ltitoolkit/core/redirect.py +24 -0
  66. ltitoolkit/core/registration.py +119 -0
  67. ltitoolkit/core/request.py +17 -0
  68. ltitoolkit/core/roles.py +109 -0
  69. ltitoolkit/core/service_connector.py +144 -0
  70. ltitoolkit/core/session.py +70 -0
  71. ltitoolkit/core/tool_config/__init__.py +4 -0
  72. ltitoolkit/core/tool_config/abstract.py +117 -0
  73. ltitoolkit/core/tool_config/dict.py +253 -0
  74. ltitoolkit/core/tool_config/json_file.py +100 -0
  75. ltitoolkit/core/tool_config/py.typed +0 -0
  76. ltitoolkit/core/utils.py +10 -0
  77. ltitoolkit/dynamic_registration/__init__.py +39 -0
  78. ltitoolkit/dynamic_registration/models.py +192 -0
  79. ltitoolkit/dynamic_registration/service.py +156 -0
  80. ltitoolkit/dynamic_registration/store.py +40 -0
  81. ltitoolkit/dynamic_registration/tool_conf.py +102 -0
  82. ltitoolkit/exceptions.py +42 -0
  83. ltitoolkit/fastapi/__init__.py +30 -0
  84. ltitoolkit/fastapi/cookie.py +53 -0
  85. ltitoolkit/fastapi/dynamic_registration.py +40 -0
  86. ltitoolkit/fastapi/message_launch.py +60 -0
  87. ltitoolkit/fastapi/oidc_login.py +47 -0
  88. ltitoolkit/fastapi/redirect.py +54 -0
  89. ltitoolkit/fastapi/request.py +77 -0
  90. ltitoolkit/fastapi/session.py +13 -0
  91. ltitoolkit/http.py +80 -0
  92. ltitoolkit/token/__init__.py +20 -0
  93. ltitoolkit/token/cache.py +47 -0
  94. ltitoolkit/token/service.py +165 -0
@@ -0,0 +1,119 @@
1
+ import json
2
+ import typing as t
3
+ import typing_extensions as te
4
+ from jwcrypto.jwk import JWK # type: ignore
5
+
6
+
7
class TKey(te.TypedDict, total=True):
    """One entry of a key set, as typed here: only ``kid`` and ``alg`` are declared."""

    kid: str
    alg: str


class TKeySet(te.TypedDict, total=True):
    """A key-set document: a ``keys`` list of :class:`TKey` entries."""

    keys: t.List[TKey]
9
+
10
+
11
class Registration:
    """Settings for one platform registration of this tool.

    Stores the platform issuer and client id, the platform key-set, auth token
    and auth login URLs, an optional token audience, and the tool's own key
    pair. Every setter returns ``self`` so calls can be chained; anything that
    has not been set is ``None``.
    """

    _issuer: t.Optional[str] = None
    _client_id: t.Optional[str] = None
    _key_set_url: t.Optional[str] = None
    _key_set: t.Optional[TKeySet] = None
    _auth_token_url: t.Optional[str] = None
    _auth_login_url: t.Optional[str] = None
    _tool_private_key: t.Optional[str] = None
    _auth_audience: t.Optional[str] = None
    # PEM text; get_jwk() encodes it as UTF-8 before parsing it as PEM.
    _tool_public_key: t.Optional[str] = None
    _kid: t.Optional[str] = None

    def get_issuer(self) -> t.Optional[str]:
        return self._issuer

    def set_issuer(self, issuer: str) -> "Registration":
        self._issuer = issuer
        return self

    def get_client_id(self) -> t.Optional[str]:
        return self._client_id

    def set_client_id(self, client_id: str) -> "Registration":
        self._client_id = client_id
        return self

    def get_key_set(self) -> t.Optional[TKeySet]:
        return self._key_set

    def set_key_set(self, key_set: t.Optional[TKeySet]) -> "Registration":
        self._key_set = key_set
        return self

    def get_key_set_url(self) -> t.Optional[str]:
        return self._key_set_url

    def set_key_set_url(self, key_set_url: t.Optional[str]) -> "Registration":
        self._key_set_url = key_set_url
        return self

    def get_auth_token_url(self) -> t.Optional[str]:
        return self._auth_token_url

    def set_auth_token_url(self, auth_token_url: str) -> "Registration":
        self._auth_token_url = auth_token_url
        return self

    def get_auth_login_url(self) -> t.Optional[str]:
        return self._auth_login_url

    def set_auth_login_url(self, auth_login_url: str) -> "Registration":
        self._auth_login_url = auth_login_url
        return self

    def get_auth_audience(self) -> t.Optional[str]:
        """Return the optional ``aud`` override for signed client assertions."""
        return self._auth_audience

    def set_auth_audience(self, auth_audience: str) -> "Registration":
        self._auth_audience = auth_audience
        return self

    def get_tool_private_key(self) -> t.Optional[str]:
        return self._tool_private_key

    def set_tool_private_key(self, tool_private_key: str) -> "Registration":
        self._tool_private_key = tool_private_key
        return self

    def get_tool_public_key(self) -> t.Optional[str]:
        return self._tool_public_key

    def set_tool_public_key(self, tool_public_key: t.Optional[str]) -> "Registration":
        self._tool_public_key = tool_public_key
        return self

    @classmethod
    def get_jwk(cls, public_key: str) -> t.Mapping[str, t.Any]:
        """Convert a PEM public key into a public JWK dict.

        The exported key is tagged ``"alg": "RS256"`` and ``"use": "sig"``.
        Errors from parsing an invalid PEM are propagated from jwcrypto.
        """
        jwk_obj = JWK.from_pem(public_key.encode("utf-8"))
        public_jwk = json.loads(jwk_obj.export_public())
        public_jwk["alg"] = "RS256"
        public_jwk["use"] = "sig"
        return public_jwk

    def get_jwks(self) -> t.List[t.Mapping[str, t.Any]]:
        """Return the tool's public JWKs: one entry if a public key is set, else []."""
        public_key = self.get_tool_public_key()
        if not public_key:
            return []
        # Dispatch through ``self`` (not the hard-coded class) so a subclass
        # that overrides get_jwk is honoured here as well.
        return [self.get_jwk(public_key)]

    def set_kid(self, kid: str) -> "Registration":
        """Pin the ``kid`` used in signed client assertions.

        Set this when the tool publishes its JWKS with an explicit key id (e.g.
        a named ``kid`` rather than a PEM-derived thumbprint), so the assertion
        header ``kid`` matches a key the platform can resolve in the JWKS. When
        unset, ``get_kid`` falls back to the public key's own ``kid`` (if any).
        """
        self._kid = kid
        return self

    def get_kid(self) -> t.Optional[str]:
        """Return the key id for assertion headers.

        Uses the pinned ``kid`` when it is truthy, otherwise the ``kid`` of the
        JWK derived from the tool public key, otherwise ``None``.
        """
        if self._kid:
            return self._kid
        key = self.get_tool_public_key()
        if not key:
            return None
        # Same dispatch as get_jwks so both agree on the derived key.
        jwk = self.get_jwk(key)
        return jwk.get("kid") if jwk else None
@@ -0,0 +1,17 @@
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+
4
class Request:
    """Minimal, framework-independent view of an incoming request.

    Subclasses provide the session object, the ``is_secure`` flag and
    parameter lookup; every hook here only raises NotImplementedError.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and has no effect
    # on Python 3, so the @abstractmethod markers are not enforced when the
    # class is instantiated; the methods simply raise when called.
    __metaclass__ = ABCMeta

    @property
    def session(self):
        """Framework session object; must be supplied by a subclass."""
        raise NotImplementedError()

    @abstractmethod
    def is_secure(self) -> bool:
        """Report whether the request is secure (defined by the subclass)."""
        raise NotImplementedError()

    @abstractmethod
    def get_param(self, key: str) -> str:
        """Return the request parameter named *key* (defined by the subclass)."""
        raise NotImplementedError()
@@ -0,0 +1,109 @@
1
+ from abc import ABCMeta
2
+ import typing as t
3
+ import typing_extensions as te
4
+
5
+
6
class RoleType:
    """Role categories recognised when parsing LIS v2 role URIs.

    Each value is compared with the first path segment that follows the LIS
    v2 vocabulary prefix of a role URI.
    """

    SYSTEM: te.Final = "system"
    INSTITUTION: te.Final = "institution"
    # Context (course-level) roles live under the "membership" path segment.
    CONTEXT: te.Final = "membership"
10
+
11
+
12
class AbstractRole:
    """Base class for checking a launch's LTI roles claim against a role family.

    Subclasses list accepted role names per category in ``_system_roles``,
    ``_institution_roles`` and ``_context_roles``; ``_common_roles`` is used
    for roles whose category could not be determined. ``check()`` is True
    when any role in the claim is accepted.
    """

    # NOTE(review): Python 2 metaclass spelling; it has no effect on Python 3.
    __metaclass__ = ABCMeta
    _base_prefix: str = "http://purl.imsglobal.org/vocab/lis/v2"
    _role_types = [RoleType.SYSTEM, RoleType.INSTITUTION, RoleType.CONTEXT]
    _jwt_roles: t.List[str] = []
    _common_roles: t.Optional[t.Tuple] = None
    _system_roles: t.Optional[t.Tuple] = None
    _institution_roles: t.Optional[t.Tuple] = None
    _context_roles: t.Optional[t.Tuple] = None

    def __init__(self, jwt_body):
        """Read the LTI roles claim from *jwt_body*.

        A missing claim and an explicit ``null`` both mean "no roles"; the
        latter previously made ``check()`` fail while iterating ``None``.
        """
        self._jwt_roles = (
            jwt_body.get("https://purl.imsglobal.org/spec/lti/claim/roles", []) or []
        )

    def check(self) -> bool:
        """Return True if any role in the claim is accepted by this class."""
        return any(
            self._check_access(*self.parse_role_str(role_str))
            for role_str in self._jwt_roles
        )

    def _check_access(self, role_name: str, role_type: t.Optional[str] = None):
        """Return True if *role_name* is listed for *role_type*.

        ``role_type`` None means the category is unknown; such names are
        checked against ``_common_roles``. Unrecognised types never match.
        """
        if role_type == RoleType.SYSTEM:
            allowed = self._system_roles
        elif role_type == RoleType.INSTITUTION:
            allowed = self._institution_roles
        elif role_type == RoleType.CONTEXT:
            allowed = self._context_roles
        elif role_type is None:
            allowed = self._common_roles
        else:
            return False
        return bool(allowed and role_name in allowed)

    def parse_role_str(self, role_str: str) -> t.Tuple[str, t.Optional[str]]:
        """Split a role string into ``(role_name, role_type)``.

        For prefixed URIs of the form ``<prefix>/<segment>...#<name>`` the name
        is the part after the first ``#`` and the type is the first path
        segment after the prefix, reported only if it is in ``_role_types``.
        A URI like ``<prefix>/membership#Learner`` therefore yields the segment
        ``membership#Learner`` (not a known type) and is returned as
        ``("Learner", None)``, i.e. it is matched against ``_common_roles``.
        Strings without the prefix, and prefixed strings lacking a path
        segment or a ``#`` fragment, are returned unchanged with type None.
        """
        if not role_str.startswith(self._base_prefix):
            return role_str, None

        role = role_str[len(self._base_prefix) :]
        role_parts = role.split("/")
        role_name_parts = role.split("#")
        if len(role_parts) < 2 or len(role_name_parts) < 2:
            # Malformed prefixed role: fall back explicitly instead of ever
            # touching an unassigned role_name.
            return role_str, None

        role_type = role_parts[1]
        role_name = role_name_parts[1]
        if role_type in self._role_types:
            return role_name, role_type
        return role_name, None
72
+
73
+
74
class StaffRole(AbstractRole):
    """Staff roles.

    Accepts system-level Administrator/SysAdmin and institution-level
    Faculty, SysAdmin, Staff or Instructor. No context or uncategorised
    roles are listed, so those never match.
    """

    _institution_roles = ("Faculty", "SysAdmin", "Staff", "Instructor")
    _system_roles = ("Administrator", "SysAdmin")
77
+
78
+
79
class StudentRole(AbstractRole):
    """Learner-type roles.

    Learner/Member match as context, institution or uncategorised roles;
    User matches as system, institution or uncategorised role; Student and
    ProspectiveStudent match only at institution level.
    """

    _context_roles = ("Learner", "Member")
    _institution_roles = ("Student", "Learner", "Member", "ProspectiveStudent", "User")
    _system_roles = ("User",)
    _common_roles = ("Learner", "Member", "User")
84
+
85
+
86
class TeacherRole(AbstractRole):
    """Instructor or Administrator, as a context role or an uncategorised role."""

    _context_roles = ("Instructor", "Administrator")
    _common_roles = ("Instructor", "Administrator")
89
+
90
+
91
class TeachingAssistantRole(AbstractRole):
    """TeachingAssistant, matched only when the parsed role type is the context one.

    With no common roles, only URIs whose first path segment is ``membership``
    (e.g. a ``membership/<principal>#TeachingAssistant`` sub-role) qualify.
    """

    _context_roles = ("TeachingAssistant",)
93
+
94
+
95
class DesignerRole(AbstractRole):
    """ContentDeveloper, as a context role or an uncategorised role."""

    _context_roles = ("ContentDeveloper",)
    _common_roles = ("ContentDeveloper",)
98
+
99
+
100
class ObserverRole(AbstractRole):
    """Mentor, as a context role or an uncategorised role."""

    _context_roles = ("Mentor",)
    _common_roles = ("Mentor",)
103
+
104
+
105
class TransientRole(AbstractRole):
    """Transient, accepted in every role category and when uncategorised."""

    _context_roles = ("Transient",)
    _institution_roles = ("Transient",)
    _system_roles = ("Transient",)
    _common_roles = ("Transient",)
@@ -0,0 +1,144 @@
1
+ import hashlib
2
+ import re
3
+ import time
4
+ import typing as t
5
+ import uuid
6
+
7
+ import jwt # type: ignore
8
+ import requests
9
+ import typing_extensions as te
10
+ from .exception import LtiServiceException
11
+ from .registration import Registration
12
+
13
class TServiceConnectorResponse(te.TypedDict):
    """Result returned by ``ServiceConnector.make_service_request``."""

    # Raw response headers (case-insensitive mapping) or a plain dict copy.
    headers: t.Union[t.Dict[str, str], t.MutableMapping[str, str]]
    # Decoded JSON body, or None when the response had no content.
    body: t.Union[None, int, float, t.List[object], t.Dict[str, object], str]
    # Target of the rel="next" Link header, if any.
    next_page_url: t.Optional[str]
21
+
22
+
23
# Default User-Agent header value used by ServiceConnector's HTTP session.
REQUESTS_USER_AGENT: str = "PyLTI1p3-client"
24
+
25
+
26
class ServiceConnector:
    """HTTP client for LTI Advantage service calls made under one Registration.

    Access tokens are obtained through the OAuth2 ``client_credentials`` grant
    with a JWT client assertion signed by the tool's private key, and are
    cached per (order-insensitive) scope set.
    """

    _registration: Registration
    _access_tokens: t.Dict[str, str]
    # scope key -> time.time() value after which the cached token is refetched.
    _access_token_expirations: t.Dict[str, float]

    # Seconds before the advertised expiry at which a cached token is refreshed.
    _access_token_expiry_leeway: float = 30.0

    # rel="next" entry of a Link header. Matched case-insensitively WITHOUT
    # lower-casing the header: URLs (e.g. opaque page cursors) are
    # case-sensitive, and lower-casing them broke pagination.
    _next_link_re = re.compile(
        r'<([^>]*)>\s*;\s*rel="?next"?(?=["\s;,]|$)', re.IGNORECASE
    )

    def __init__(
        self,
        registration: Registration,
        requests_session: t.Optional[requests.Session] = None,
    ):
        self._registration = registration
        self._access_tokens = {}
        if requests_session:
            # Caller-supplied sessions are used as-is.
            self._requests_session = requests_session
        else:
            self._requests_session = requests.Session()
            self._requests_session.headers["User-Agent"] = REQUESTS_USER_AGENT

    def _token_expirations(self) -> t.Dict[str, float]:
        """Return the expiry map, creating it lazily (works even if __init__ was bypassed)."""
        try:
            return self._access_token_expirations
        except AttributeError:
            self._access_token_expirations = {}
            return self._access_token_expirations

    def _is_access_token_expired(self, scope_key: str) -> bool:
        """True only if the token cached under *scope_key* has a known, passed expiry."""
        expires_at = self._token_expirations().get(scope_key)
        return expires_at is not None and time.time() >= expires_at

    def _remember_access_token_expiry(self, scope_key: str, expires_in: object) -> None:
        """Record when the token just cached under *scope_key* should be refreshed.

        *expires_in* is the token endpoint's lifetime in seconds. Missing or
        unparseable values leave the token without an expiry, i.e. it stays
        cached for the life of this connector as before.
        """
        lifetime: t.Optional[float] = None
        if expires_in is not None and not isinstance(expires_in, bool):
            try:
                lifetime = float(expires_in)
            except (TypeError, ValueError):
                lifetime = None
        expirations = self._token_expirations()
        if lifetime is None:
            expirations.pop(scope_key, None)
            return
        # Refresh a little early, but never spend more than half of a short
        # lifetime on the safety margin.
        margin = min(self._access_token_expiry_leeway, max(lifetime, 0.0) / 2)
        expirations[scope_key] = time.time() + lifetime - margin

    def get_access_token(self, scopes: t.Sequence[str]) -> str:
        """Return a bearer token for *scopes*, requesting a new one when needed.

        Raises:
            LtiServiceException: the token endpoint answered with an error status.
        """
        # Don't fetch the same key more than once
        scopes = sorted(scopes)
        scopes_bytes = "|".join(scopes).encode("utf-8")
        scope_key = hashlib.md5(scopes_bytes).hexdigest()

        if scope_key in self._access_tokens and not self._is_access_token_expired(
            scope_key
        ):
            return self._access_tokens[scope_key]

        # Build up JWT to exchange for an auth token
        client_id = self._registration.get_client_id()
        assert client_id is not None, "client_id should be set at this point"
        auth_url = self._registration.get_auth_token_url()
        assert auth_url is not None, "auth_url should be set at this point"
        auth_audience = self._registration.get_auth_audience()
        aud = auth_audience if auth_audience else auth_url

        now = int(time.time())
        jwt_claim: t.Dict[str, t.Union[str, int]] = {
            "iss": str(client_id),
            "sub": str(client_id),
            "aud": str(aud),
            # Backdated a few seconds, presumably to tolerate clock skew.
            "iat": now - 5,
            "exp": now + 60,
            "jti": "lti-service-token-" + str(uuid.uuid4()),
        }
        kid = self._registration.get_kid()
        headers = {"kid": kid} if kid else {}

        # Sign the JWT with our private key (given by the platform on registration)
        private_key = self._registration.get_tool_private_key()
        assert private_key is not None, "Private key should be set at this point"
        jwt_val = self.encode_jwt(jwt_claim, private_key, headers)

        auth_request = {
            "grant_type": "client_credentials",
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": jwt_val,
            "scope": " ".join(scopes),
        }

        # Make request to get auth token
        r = self._requests_session.post(auth_url, data=auth_request)
        if not r.ok:
            raise LtiServiceException(r)
        response = r.json()

        access_token = response["access_token"]
        self._access_tokens[scope_key] = access_token
        self._remember_access_token_expiry(scope_key, response.get("expires_in"))
        return access_token

    def encode_jwt(
        self,
        message: t.Dict[str, t.Union[str, int]],
        private_key: str,
        headers: t.Dict[str, str],
    ) -> str:
        """Sign *message* with RS256 and return the compact JWT as ``str``."""
        jwt_val = jwt.encode(message, private_key, algorithm="RS256", headers=headers)
        # Older PyJWT versions return bytes.
        if isinstance(jwt_val, bytes):
            return jwt_val.decode("utf-8")
        return jwt_val

    @classmethod
    def _parse_next_page_url(cls, link_header: t.Optional[str]) -> t.Optional[str]:
        """Return the rel="next" target of a Link header (case preserved), or None."""
        if not link_header:
            return None
        # Unfold multi-line header values before matching.
        match = cls._next_link_re.search(link_header.replace("\n", " "))
        if match is None:
            return None
        return match.group(1) or None

    def make_service_request(
        self,
        scopes: t.Sequence[str],
        url: str,
        is_post: bool = False,
        data: t.Optional[str] = None,
        content_type: str = "application/json",
        accept: str = "application/json",
        case_insensitive_headers: bool = False,
    ) -> TServiceConnectorResponse:
        """Call *url* with a bearer token for *scopes* and return headers/body/next page.

        Raises:
            LtiServiceException: the service answered with an error status.
        """
        access_token = self.get_access_token(scopes)
        headers = {"Authorization": "Bearer " + access_token, "Accept": accept}

        if is_post:
            headers["Content-Type"] = content_type
            post_data = data or None
            r = self._requests_session.post(url, data=post_data, headers=headers)
        else:
            r = self._requests_session.get(url, headers=headers)

        if not r.ok:
            raise LtiServiceException(r)

        next_page_url = self._parse_next_page_url(r.headers.get("link", ""))

        return {
            "headers": r.headers if case_insensitive_headers else dict(r.headers),
            "body": r.json() if r.content else None,
            "next_page_url": next_page_url,
        }
@@ -0,0 +1,70 @@
1
+ import typing as t
2
+ from .launch_data_storage.session import SessionDataStorage
3
+ from .request import Request
4
+ from .launch_data_storage.base import LaunchDataStorage
5
+
6
+
7
# Parameters saved against a login ``state`` value.
TStateParams = t.Dict[str, object]
# Launch JWT body: a mapping of claim names to values (assumed decoded).
TJwtBody = t.Mapping[str, t.Any]
9
+
10
+
11
class SessionService:
    """Persists launch bookkeeping: launch data, nonces and login state.

    Values go through ``data_storage``, which starts as a SessionDataStorage
    bound to the given request and can be swapped via ``set_data_storage``.
    Every write passes ``_launch_data_lifetime`` as its expiry.
    """

    data_storage: LaunchDataStorage[t.Any]
    _launch_data_lifetime = 86400  # seconds (one day)
    _session_prefix = "lti1p3"

    def __init__(self, request: Request):
        storage = SessionDataStorage()
        storage.set_request(request)
        self.data_storage = storage

    def _get_key(
        self, key: str, nonce: t.Optional[str] = None, add_prefix: bool = True
    ):
        """Build a storage key shaped ``[<prefix>-]<key>[-<nonce>]``."""
        pieces = [self._session_prefix] if add_prefix else []
        pieces.append(key)
        if nonce:
            pieces.append(nonce)
        return "-".join(pieces)

    def _set_value(self, key: str, value: object):
        self.data_storage.set_value(key, value, exp=self._launch_data_lifetime)

    def _get_value(self, key: str) -> t.Any:
        return self.data_storage.get_value(key)

    def _id_token_hash_key(self, state: str) -> str:
        """Key under which the id-token hash for *state* is stored."""
        return self._get_key(state + "-id-token-hash")

    def get_launch_data(self, key: str) -> TJwtBody:
        """Load launch data stored under the unprefixed *key*."""
        return self._get_value(self._get_key(key, add_prefix=False))

    def save_launch_data(self, key: str, jwt_body: TJwtBody):
        """Store *jwt_body* under the unprefixed *key*."""
        self._set_value(self._get_key(key, add_prefix=False), jwt_body)

    def save_nonce(self, nonce: str):
        """Mark *nonce* as issued."""
        self._set_value(self._get_key("nonce", nonce), True)

    def check_nonce(self, nonce: str) -> bool:
        """Delegate to ``data_storage.check_value`` for the nonce's key."""
        return self.data_storage.check_value(self._get_key("nonce", nonce))

    def save_state_params(self, state: str, params: TStateParams):
        self._set_value(self._get_key(state), params)

    def get_state_params(self, state: str) -> TStateParams:
        return self._get_value(self._get_key(state))

    def set_state_valid(self, state: str, id_token_hash: str):
        """Remember *id_token_hash* as the one accepted for *state*."""
        return self._set_value(self._id_token_hash_key(state), id_token_hash)

    def check_state_is_valid(self, state: str, id_token_hash: str) -> bool:
        """Compare *id_token_hash* with the hash recorded for *state*."""
        stored_hash = self._get_value(self._id_token_hash_key(state))
        return stored_hash == id_token_hash

    def set_data_storage(self, data_storage: LaunchDataStorage[t.Any]):
        self.data_storage = data_storage

    def set_launch_data_lifetime(self, time_sec: int):
        """Change the expiry used for writes, if the storage supports it.

        Raises:
            Exception: the current storage cannot set key expiration.
        """
        if not self.data_storage.can_set_keys_expiration():
            raise Exception(
                f"{self.data_storage.__class__.__name__} launch storage doesn't support "
                "manual change expiration of the keys"
            )
        self._launch_data_lifetime = time_sec
@@ -0,0 +1,4 @@
1
+ # flake8: noqa
2
+ from .abstract import ToolConfAbstract
3
+ from .dict import ToolConfDict
4
+ from .json_file import ToolConfJsonFile
@@ -0,0 +1,117 @@
1
+ import typing as t
2
+ from abc import ABCMeta, abstractmethod
3
+ import typing_extensions as te
4
+ from ..deployment import Deployment
5
+ from ..registration import Registration
6
+ from ..request import Request
7
+
8
+
9
# Concrete Request subclass a ToolConfAbstract implementation works with.
REQ = t.TypeVar("REQ", bound=Request)
10
+
11
+
12
class IssuerToClientRelation:
    """How many client ids a configured issuer maps to.

    ONE_CLIENT_ID_PER_ISSUER is the older, backward-compatible layout;
    MANY_CLIENTS_IDS_PER_ISSUER allows several client ids per issuer.
    """

    ONE_CLIENT_ID_PER_ISSUER: te.Final = "one-issuer-one-client-id"
    MANY_CLIENTS_IDS_PER_ISSUER: te.Final = "one-issuer-many-client-ids"
15
+
16
+
17
class ToolConfAbstract(t.Generic[REQ]):
    """Base class for tool configuration backends.

    Subclasses look up Registration and Deployment objects for a platform
    issuer. An issuer is configured either with a single client id (legacy
    layout) or with several client ids; the layout of each issuer is tracked
    in ``issuers_relation_types``.
    """

    # NOTE(review): ``__metaclass__`` is the Python 2 spelling and is ignored
    # by Python 3, so the @abstractmethod markers below are not enforced at
    # instantiation. The docstrings presumably rely on that, since they allow
    # subclasses to skip the lookups for the layout they do not use.
    __metaclass__ = ABCMeta
    # Class-level default only. The setters copy it into a per-instance dict
    # before writing, so one configuration's issuer layouts never leak into
    # other instances through this shared mapping.
    issuers_relation_types: t.MutableMapping[str, str] = {}

    def _own_issuers_relation_types(self) -> t.MutableMapping[str, str]:
        """Return this instance's relation map, copying the class default on first use."""
        own = self.__dict__.get("issuers_relation_types")
        if own is None:
            own = dict(self.issuers_relation_types)
            self.issuers_relation_types = own
        return own

    def check_iss_has_one_client(self, iss: str) -> bool:
        """Return True if *iss* uses the single-client-id layout.

        The two methods check_iss_has_one_client / check_iss_has_many_clients are
        needed for backward compatibility with previous versions of the library
        (1.4.0 and earlier), where ToolConfDict supported only one client_id per
        iss. Issuers with no recorded layout count as single-client.
        Should return False for all new ToolConf-s.
        """
        iss_type = self.issuers_relation_types.get(
            iss, IssuerToClientRelation.ONE_CLIENT_ID_PER_ISSUER
        )
        return iss_type == IssuerToClientRelation.ONE_CLIENT_ID_PER_ISSUER

    def check_iss_has_many_clients(self, iss: str) -> bool:
        """Return True if *iss* uses the many-client-ids layout.

        Should return True for all new ToolConf-s.
        """
        iss_type = self.issuers_relation_types.get(
            iss, IssuerToClientRelation.ONE_CLIENT_ID_PER_ISSUER
        )
        return iss_type == IssuerToClientRelation.MANY_CLIENTS_IDS_PER_ISSUER

    def set_iss_has_one_client(self, iss: str):
        """Record that *iss* uses the single-client-id layout (this instance only)."""
        self._own_issuers_relation_types()[
            iss
        ] = IssuerToClientRelation.ONE_CLIENT_ID_PER_ISSUER

    def set_iss_has_many_clients(self, iss: str):
        """Record that *iss* uses the many-client-ids layout (this instance only)."""
        self._own_issuers_relation_types()[
            iss
        ] = IssuerToClientRelation.MANY_CLIENTS_IDS_PER_ISSUER

    def find_registration(self, iss: str, *args, **kwargs) -> Registration:
        """
        Backward compatibility method; delegates to find_registration_by_issuer.
        """
        return self.find_registration_by_issuer(iss, *args, **kwargs)

    @abstractmethod
    def find_registration_by_issuer(self, iss: str, *args, **kwargs) -> Registration:
        """
        Find the registration for an issuer that has only one client id, i.e.
        for a { ... "iss": { ... "client_id": "client" ... }, ... } config.

        You may skip implementing this method if every iss in your config can
        have more than one client id.
        """
        raise NotImplementedError

    @abstractmethod
    def find_registration_by_params(
        self, iss: str, client_id: str, *args, **kwargs
    ) -> Registration:
        """
        Find the registration for an issuer that has many client ids, i.e. for a
        { ... "iss": [ { ... "client_id": "client1" ... }, { ... "client_id": "client2" ... } ], ... } config.

        You may skip implementing this method if no iss in your config can have
        more than one client id, but that is an outdated and not recommended way
        of storing configuration.
        """
        raise NotImplementedError

    @abstractmethod
    def find_deployment(self, iss: str, deployment_id: str) -> t.Optional[Deployment]:
        """
        Find a deployment for an issuer that has only one client id, i.e.
        for a { ... "iss": { ... "client_id": "client" ... }, ... } config.

        You may skip implementing this method if every iss in your config can
        have more than one client id.
        """
        raise NotImplementedError

    @abstractmethod
    def find_deployment_by_params(
        self, iss: str, deployment_id: str, client_id: str, *args, **kwargs
    ) -> t.Optional[Deployment]:
        """
        Find a deployment for an issuer that has many client ids, i.e. for a
        { ... "iss": [ { ... "client_id": "client1" ... }, { ... "client_id": "client2" ... } ], ... } config.

        You may skip implementing this method if no iss in your config can have
        more than one client id, but that is an outdated and not recommended way
        of storing configuration.
        """
        raise NotImplementedError

    def get_jwks(
        self, iss: t.Optional[str] = None, client_id: t.Optional[str] = None, **kwargs
    ) -> t.Dict[str, t.List[t.Mapping[str, t.Any]]]:
        """Return the tool's public keys as ``{"keys": [...]}``.

        Without *iss* the key list is empty. For a many-client issuer
        *client_id* is required; extra keyword arguments are forwarded to
        ``find_registration_by_params``.

        Raises:
            Exception: *client_id* is missing for a many-client issuer, or the
                issuer's recorded relation type is not recognised.
        """
        keys: t.List[t.Mapping[str, t.Any]] = []
        if iss:
            if self.check_iss_has_one_client(iss):
                reg = self.find_registration(iss)
            elif self.check_iss_has_many_clients(iss):
                if not client_id:
                    raise Exception("client_id is not specified")
                reg = self.find_registration_by_params(iss, client_id, **kwargs)
            else:
                raise Exception("Invalid issuer relation type")
            keys = reg.get_jwks()
        return {"keys": keys}