iaptoolkit 0.2.5__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: iaptoolkit
3
- Version: 0.2.5
3
+ Version: 0.3.0
4
4
  Summary: Library of common utils for interacting with Identity-Aware Proxies
5
5
  Author: Rob Voigt
6
6
  Author-email: code@ravoigt.com
@@ -8,6 +8,7 @@ Requires-Python: >=3.11,<4.0
8
8
  Classifier: Programming Language :: Python :: 3
9
9
  Classifier: Programming Language :: Python :: 3.11
10
10
  Classifier: Programming Language :: Python :: 3.12
11
+ Classifier: Programming Language :: Python :: 3.13
11
12
  Requires-Dist: google-auth (>=2.29.0,<3.0.0)
12
13
  Requires-Dist: kvcommon (>=0.1.3,<0.2.0)
13
14
  Requires-Dist: pytest (>=7.4.4,<8.0.0)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "iaptoolkit"
3
- version = "0.2.5"
3
+ version = "0.3.0"
4
4
  description = "Library of common utils for interacting with Identity-Aware Proxies"
5
5
  authors = ["Rob Voigt <code@ravoigt.com>"]
6
6
  readme = "README.md"
@@ -16,7 +16,7 @@ Repository = "https://github.com/RAVoigt/iaptoolkit"
16
16
  # ================================
17
17
  # Tools etc.
18
18
  [tool.black]
19
- line-length = 100
19
+ line-length = 120
20
20
  target-version = ['py311']
21
21
  include = '\.pyi?$'
22
22
 
@@ -16,6 +16,7 @@ from iaptoolkit.tokens.service_account import ServiceAccount
16
16
  from iaptoolkit.tokens.structs import ResultAddTokenHeader
17
17
 
18
18
  from iaptoolkit.tokens.structs import TokenRefreshStruct
19
+ from iaptoolkit.tokens.structs import TokenStruct
19
20
  from iaptoolkit.utils.urls import is_url_safe_for_token
20
21
 
21
22
  LOG = logger.get_logger("iaptk")
@@ -38,11 +39,9 @@ class IAPToolkit:
38
39
  def sanitize_request_headers(request_headers: dict) -> dict:
39
40
  return headers.sanitize_request_headers(request_headers)
40
41
 
41
- def get_token_oidc(self, bypass_cached: bool = False) -> TokenRefreshStruct:
42
+ def get_token_oidc(self, bypass_cached: bool = False) -> TokenStruct:
42
43
  try:
43
- return ServiceAccount.get_token(
44
- iap_client_id=self._GOOGLE_IAP_CLIENT_ID, bypass_cached=bypass_cached
45
- )
44
+ return ServiceAccount.get_token(iap_client_id=self._GOOGLE_IAP_CLIENT_ID, bypass_cached=bypass_cached)
46
45
  except ServiceAccountTokenException as ex:
47
46
  LOG.debug(ex)
48
47
  raise
@@ -60,7 +59,11 @@ class IAPToolkit:
60
59
  return struct.id_token
61
60
 
62
61
  def get_token_and_add_to_headers(
63
- self, request_headers: dict, use_oauth2: bool = False, use_auth_header: bool = False
62
+ self,
63
+ request_headers: dict,
64
+ use_oauth2: bool = False,
65
+ use_auth_header: bool = False,
66
+ bypass_cached: bool = False,
64
67
  ) -> bool:
65
68
  """
66
69
  Retrieves an auth token and inserts it into the supplied request_headers dict.
@@ -72,20 +75,29 @@ class IAPToolkit:
72
75
  As a general guideline, OIDC is the assumed default approach for ServiceAccounts.
73
76
  use_auth_header: If true, use the 'Authorization' header instead of 'Proxy-Authorization'
74
77
 
78
+ Returns:
79
+ True if token retrieved from cache, False if fresh from API
80
+
75
81
 
76
82
  """
77
- if not use_oauth2:
78
- token_refresh_struct: TokenRefreshStruct = self.get_token_oidc()
83
+ id_token = None
84
+ from_cache = False
85
+ if use_oauth2:
86
+ token_refresh_struct: TokenRefreshStruct = self.get_token_oauth2(bypass_cached=bypass_cached)
87
+ id_token = token_refresh_struct.id_token
88
+ from_cache = token_refresh_struct.from_cache
79
89
  else:
80
- token_refresh_struct: TokenRefreshStruct = self.get_token_oauth2()
90
+ token_struct: TokenStruct = self.get_token_oidc(bypass_cached=bypass_cached)
91
+ id_token = token_struct.id_token
92
+ from_cache = token_struct.from_cache
81
93
 
82
94
  headers.add_token_to_request_headers(
83
95
  request_headers=request_headers,
84
- id_token=token_refresh_struct.id_token,
96
+ id_token=id_token,
85
97
  use_auth_header=use_auth_header,
86
98
  )
87
99
 
88
- return token_refresh_struct.token_is_new
100
+ return from_cache
89
101
 
90
102
  @staticmethod
91
103
  def is_url_safe_for_token(
@@ -104,6 +116,7 @@ class IAPToolkit:
104
116
  valid_domains: t.List[str] | None = None,
105
117
  use_oauth2: bool = False,
106
118
  use_auth_header: bool = False,
119
+ bypass_cached: bool = False,
107
120
  ) -> ResultAddTokenHeader:
108
121
  """
109
122
  Checks that the supplied URL is valid (i.e.; in valid_domains) and if so, retrieves a
@@ -123,10 +136,11 @@ class IAPToolkit:
123
136
  request_headers=request_headers,
124
137
  use_oauth2=use_oauth2,
125
138
  use_auth_header=use_auth_header,
139
+ bypass_cached=bypass_cached,
126
140
  )
127
141
  return ResultAddTokenHeader(token_added=True, token_is_fresh=token_is_fresh)
128
142
  else:
129
- LOG.warn(
143
+ LOG.warning(
130
144
  "URL is not approved: %s - Token will not be added to headers. Valid domains are: %s",
131
145
  url,
132
146
  valid_domains,
@@ -146,10 +160,17 @@ class IAPToolkit_OIDC(IAPToolkit):
146
160
  raise NotImplementedError("Cannot call OAuth2 methods on OIDC-only instance of IAPToolkit.")
147
161
 
148
162
  def get_token_and_add_to_headers(
149
- self, request_headers: dict, use_auth_header: bool = False, use_oauth2: bool = False,
163
+ self,
164
+ request_headers: dict,
165
+ use_auth_header: bool = False,
166
+ use_oauth2: bool = False,
167
+ bypass_cached: bool = False,
150
168
  ) -> bool:
151
169
  return super().get_token_and_add_to_headers(
152
- request_headers=request_headers, use_oauth2=use_oauth2, use_auth_header=use_auth_header
170
+ request_headers=request_headers,
171
+ use_oauth2=use_oauth2,
172
+ use_auth_header=use_auth_header,
173
+ bypass_cached=bypass_cached,
153
174
  )
154
175
 
155
176
  def check_url_and_add_token_header(
@@ -158,6 +179,7 @@ class IAPToolkit_OIDC(IAPToolkit):
158
179
  request_headers: dict,
159
180
  valid_domains: t.List[str] | None = None,
160
181
  use_auth_header: bool = False,
182
+ bypass_cached: bool = False,
161
183
  ) -> ResultAddTokenHeader:
162
184
  return super().check_url_and_add_token_header(
163
185
  url,
@@ -165,6 +187,7 @@ class IAPToolkit_OIDC(IAPToolkit):
165
187
  valid_domains=valid_domains,
166
188
  use_oauth2=False,
167
189
  use_auth_header=use_auth_header,
190
+ bypass_cached=bypass_cached,
168
191
  )
169
192
 
170
193
 
@@ -193,10 +216,17 @@ class IAPToolkit_OAuth2(IAPToolkit):
193
216
  raise NotImplementedError("Cannot call OIDC methods on OAuth2-only instance of IAPToolkit.")
194
217
 
195
218
  def get_token_and_add_to_headers(
196
- self, request_headers: dict, use_auth_header: bool = False, use_oauth2: bool = True,
219
+ self,
220
+ request_headers: dict,
221
+ use_auth_header: bool = False,
222
+ use_oauth2: bool = True,
223
+ bypass_cached: bool = False,
197
224
  ) -> bool:
198
225
  return super().get_token_and_add_to_headers(
199
- request_headers=request_headers, use_oauth2=use_oauth2, use_auth_header=use_auth_header
226
+ request_headers=request_headers,
227
+ use_oauth2=use_oauth2,
228
+ use_auth_header=use_auth_header,
229
+ bypass_cached=bypass_cached,
200
230
  )
201
231
 
202
232
  def check_url_and_add_token_header(
@@ -205,6 +235,7 @@ class IAPToolkit_OAuth2(IAPToolkit):
205
235
  request_headers: dict,
206
236
  valid_domains: t.List[str] | None = None,
207
237
  use_auth_header: bool = False,
238
+ bypass_cached: bool = False,
208
239
  ) -> ResultAddTokenHeader:
209
240
  return super().check_url_and_add_token_header(
210
241
  url=url,
@@ -212,4 +243,5 @@ class IAPToolkit_OAuth2(IAPToolkit):
212
243
  valid_domains=valid_domains,
213
244
  use_oauth2=True,
214
245
  use_auth_header=use_auth_header,
246
+ bypass_cached=bypass_cached,
215
247
  )
@@ -23,9 +23,7 @@ class TokenStorageException(TokenException):
23
23
 
24
24
 
25
25
  class ServiceAccountTokenException(TokenException):
26
- def __init__(
27
- self, message: str, google_exception: t.Union[DefaultCredentialsError, RefreshError] | None
28
- ):
26
+ def __init__(self, message: str, google_exception: t.Union[DefaultCredentialsError, RefreshError] | None):
29
27
  self.google_exception = google_exception
30
28
  credentials_env_var_value = os.environ.get(GOOGLE_CREDENTIALS_FILE_PATH)
31
29
  metadata_server_attempted = not credentials_env_var_value
@@ -34,9 +34,7 @@ def sanitize_request_headers(headers: dict) -> dict:
34
34
  return log_safe_headers
35
35
 
36
36
 
37
- def add_token_to_request_headers(
38
- request_headers: dict, id_token: str, use_auth_header: bool = False
39
- ) -> dict:
37
+ def add_token_to_request_headers(request_headers: dict, id_token: str, use_auth_header: bool = False) -> dict:
40
38
  """
41
39
  Adds Bearer token to headers dict. Modifies dict in-place.
42
40
  Returns True if added token is a fresh one, or False if token is from cache
@@ -14,6 +14,7 @@ from iaptoolkit.tokens.token_datastore import datastore
14
14
  from iaptoolkit.exceptions import ServiceAccountTokenException
15
15
  from iaptoolkit.exceptions import ServiceAccountTokenFailedRefresh
16
16
  from iaptoolkit.exceptions import ServiceAccountNoDefaultCredentials
17
+ from iaptoolkit.exceptions import TokenException
17
18
  from iaptoolkit.exceptions import TokenStorageException
18
19
 
19
20
  from .structs import TokenStruct
@@ -41,39 +42,33 @@ class ServiceAccount(object):
41
42
  def get_stored_token(iap_client_id: str) -> t.Optional[TokenStruct]:
42
43
  try:
43
44
  token_dict = datastore.get_stored_service_account_token(iap_client_id)
44
- if (
45
- not token_dict
46
- or not token_dict.get("id_token", None)
47
- or not token_dict.get("token_expiry", None)
48
- ):
45
+ if not token_dict or not token_dict.get("id_token", None) or not token_dict.get("token_expiry", None):
49
46
  LOG.debug("No stored service account token for current iap_client_id")
50
47
  return
51
48
 
52
- id_token_from_dict = token_dict.get("id_token")
53
- token_expiry_from_dict = token_dict.get("token_expiry", "")
54
-
55
- if not id_token_from_dict:
56
- LOG.warning("Invalid stored ID token")
57
- return
49
+ id_token_from_dict: str = token_dict.get("id_token", "")
50
+ token_expiry_from_dict: str = token_dict.get("token_expiry", "")
58
51
 
59
52
  token_expiry = ""
60
53
  try:
61
54
  token_expiry = datetime.datetime.fromisoformat(token_expiry_from_dict)
62
55
  except (ValueError, TypeError) as ex:
63
- LOG.debug("Invalid token expiry for current iap_client_id")
56
+ LOG.debug("Invalid token expiry for stored token - Could not parse from ISO format to datetime.")
64
57
  return
65
58
 
66
- token_struct = TokenStruct(id_token=id_token_from_dict, expiry=token_expiry)
59
+ token_struct = TokenStruct(id_token=id_token_from_dict, expiry=token_expiry, from_cache=True)
60
+ if not token_struct.valid:
61
+ LOG.debug("Stored service account token for current iap_client_id is INVALID")
62
+ return
67
63
  if token_struct.expired:
68
64
  LOG.debug("Stored service account token for current iap_client_id has EXPIRED")
69
65
  return
66
+
70
67
  return token_struct
71
68
 
72
69
  except Exception as ex:
73
70
  # Err on the side of not letting token-caching break requests, hence blanket except
74
- raise TokenStorageException(
75
- f"Exception when trying to retrieve stored token. exception={ex}"
76
- )
71
+ raise TokenStorageException(f"Exception when trying to retrieve stored token. exception={ex}")
77
72
 
78
73
  @staticmethod
79
74
  def _get_fresh_credentials(iap_client_id: str) -> GoogleIDTokenCredentials:
@@ -104,6 +99,8 @@ class ServiceAccount(object):
104
99
  def _get_fresh_token(iap_client_id: str) -> TokenStruct:
105
100
  google_credentials = ServiceAccount._get_fresh_credentials(iap_client_id)
106
101
  id_token: str = str(google_credentials.token)
102
+ if not id_token:
103
+ raise TokenException("Invalid [empty] token retrieved for Service Account.")
107
104
 
108
105
  # Google lib uses deprecated 'utcfromtimestamp' func as of v2.29.x
109
106
  # e.g.: datetime.datetime.utcfromtimestamp(payload["exp"])
@@ -111,12 +108,10 @@ class ServiceAccount(object):
111
108
  # Python datetimes assume local TZ, and we want to explicitly only work in UTC here.
112
109
  token_expiry = google_credentials.expiry.replace(tzinfo=datetime.timezone.utc)
113
110
 
114
- return TokenStruct(id_token=id_token, expiry=token_expiry)
111
+ return TokenStruct(id_token=id_token, expiry=token_expiry, from_cache=False)
115
112
 
116
113
  @staticmethod
117
- def get_token(
118
- iap_client_id: str, bypass_cached: bool = False, attempts: int = 0
119
- ) -> TokenRefreshStruct:
114
+ def get_token(iap_client_id: str, bypass_cached: bool = False, attempts: int = 0) -> TokenStruct:
120
115
  """Retrieves an OIDC token for the current environment either from environment variable or from
121
116
  metadata service.
122
117
 
@@ -140,19 +135,17 @@ class ServiceAccount(object):
140
135
  use_cache = not bypass_cached
141
136
 
142
137
  try:
143
- token_from_cache = False
144
- token_struct = (use_cache and ServiceAccount.get_stored_token(iap_client_id)) or None
145
- if use_cache and token_struct:
146
- token_from_cache = True
147
- else:
148
- token_struct = ServiceAccount._get_fresh_token(iap_client_id)
138
+ token_struct: TokenStruct | None = None
149
139
 
150
- ServiceAccount._store_token(iap_client_id, token_struct.id_token, token_struct.expiry)
140
+ if use_cache:
141
+ token_struct = ServiceAccount.get_stored_token(iap_client_id)
151
142
 
152
- token_refresh_struct = TokenRefreshStruct(
153
- id_token=token_struct.id_token, token_is_new=not token_from_cache
154
- )
155
- return token_refresh_struct
143
+ if not token_struct:
144
+ token_struct = ServiceAccount._get_fresh_token(iap_client_id)
145
+ if use_cache:
146
+ ServiceAccount._store_token(iap_client_id, token_struct.id_token, token_struct.expiry)
147
+
148
+ return token_struct
156
149
 
157
150
  except ServiceAccountTokenException as ex:
158
151
  attempts += 1
@@ -173,16 +166,14 @@ class GoogleServiceAccount(ServiceAccount):
173
166
 
174
167
  def __init__(self, iap_client_id: str) -> None:
175
168
  if not iap_client_id or not isinstance(iap_client_id, str):
176
- raise ServiceAccountTokenException(
177
- "Invalid iap_client_id for GoogleServiceAccount", google_exception=None
178
- )
169
+ raise ServiceAccountTokenException("Invalid iap_client_id for GoogleServiceAccount", google_exception=None)
179
170
  self._iap_client_id = iap_client_id
180
171
  super().__init__()
181
172
 
182
173
  def get_stored_token(self) -> t.Optional[TokenStruct]:
183
174
  return ServiceAccount.get_stored_token(self._iap_client_id)
184
175
 
185
- def get_token(self, bypass_cached: bool = False, attempts: int = 0) -> TokenRefreshStruct:
176
+ def get_token(self, bypass_cached: bool = False, attempts: int = 0) -> TokenStruct:
186
177
  return ServiceAccount.get_token(
187
178
  iap_client_id=self._iap_client_id, bypass_cached=bypass_cached, attempts=attempts
188
179
  )
@@ -8,10 +8,18 @@ from kvcommon import logger
8
8
  LOG = logger.get_logger("iaptk")
9
9
 
10
10
 
11
+ def validate_token(token: str | None) -> bool:
12
+ if not isinstance(token, str) or token.strip() == "":
13
+ return False
14
+
15
+ return True
16
+
17
+
11
18
  @dataclass(kw_only=True)
12
19
  class TokenStruct:
13
20
  id_token: str
14
21
  expiry: datetime.datetime
22
+ from_cache: bool = False
15
23
 
16
24
  @property
17
25
  def expired(self):
@@ -30,17 +38,28 @@ class TokenStruct:
30
38
  LOG.error("Exception when checking token expiry. exception=%s", ex)
31
39
  return True
32
40
 
41
+ @property
42
+ def valid(self):
43
+ return validate_token(self.id_token)
44
+
33
45
 
34
46
  @dataclass(kw_only=True)
35
47
  class TokenRefreshStruct:
36
48
  id_token: str
37
- token_is_new: bool = True
49
+ from_cache: bool = False
38
50
 
51
+ @property
52
+ def valid(self):
53
+ return validate_token(self.id_token)
39
54
 
40
55
  @dataclass(kw_only=True)
41
56
  class TokenStructOAuth2(TokenStruct):
42
57
  refresh_token: str
43
- new_refresh_token: bool = False
58
+ from_cache: bool = False
59
+
60
+ @property
61
+ def valid(self):
62
+ return validate_token(self.refresh_token)
44
63
 
45
64
 
46
65
  @dataclass(kw_only=True)
@@ -4,9 +4,11 @@ import typing as t
4
4
  from kvcommon import logger
5
5
  from kvcommon.datastore.backend import DatastoreBackend
6
6
  from kvcommon.datastore.backend import DictBackend
7
+
7
8
  # from kvcommon.datastore.backend import TOMLBackend
8
9
  from kvcommon.datastore import VersionedDatastore
9
10
 
11
+ from iaptoolkit.exceptions import TokenException
10
12
  from iaptoolkit.constants import IAPTOOLKIT_CONFIG_VERSION
11
13
 
12
14
 
@@ -38,9 +40,10 @@ class TokenDatastore(VersionedDatastore):
38
40
  return
39
41
  return token_struct_dict
40
42
 
41
- def store_service_account_token(
42
- self, iap_client_id: str, id_token: str, token_expiry: datetime.datetime
43
- ):
43
+ def store_service_account_token(self, iap_client_id: str, id_token: str, token_expiry: datetime.datetime):
44
+ if not id_token:
45
+ raise TokenException("TokenDatastore: Attempting to store invalid [empty] token")
46
+
44
47
  tokens_dict = self.get_or_create_nested_dict(self._service_account_tokens_key)
45
48
  tokens_dict[iap_client_id] = dict(id_token=id_token, token_expiry=token_expiry.isoformat())
46
49
 
@@ -62,6 +65,7 @@ class TokenDatastore(VersionedDatastore):
62
65
  # # TODO: OAuth2
63
66
  # raise NotImplementedError()
64
67
 
68
+
65
69
  datastore = TokenDatastore(DictBackend)
66
70
 
67
71
  # if PERSISTENT_DATASTORE_ENABLED:
@@ -21,7 +21,8 @@ def is_url_safe_for_token(
21
21
  f"Invalid url_parts - Expected a ParseResult - Got: "
22
22
  f"'{str(url_parts)}' (type#: {type(url_parts).__name__})"
23
23
  )
24
- if allowed_domains is not None and not isinstance(allowed_domains, (list, set, tuple)) :
24
+
25
+ if allowed_domains is not None and not isinstance(allowed_domains, (list, set, tuple)):
25
26
  raise TypeError("allowed_domains must be a list, set or tuple if not None")
26
27
 
27
28
  netloc = get_netloc_without_port_from_url_parts(url_parts)
@@ -34,8 +35,7 @@ def is_url_safe_for_token(
34
35
  for domain in allowed_domains:
35
36
  if domain == "" or not isinstance(domain, str):
36
37
  raise InvalidDomain(
37
- f"Empty or non-string domain in allowed_domains: "
38
- f"'{str(domain)}' (type#: {type(domain).__name__})"
38
+ f"Empty or non-string domain in allowed_domains: " f"'{str(domain)}' (type#: {type(domain).__name__})"
39
39
  )
40
40
 
41
41
  if netloc.endswith(domain):
File without changes
File without changes