splunk-soar-sdk 3.6.1__py3-none-any.whl → 3.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soar_sdk/actions_manager.py +40 -0
- soar_sdk/app.py +52 -4
- soar_sdk/asset.py +61 -69
- soar_sdk/asset_state.py +13 -3
- soar_sdk/auth/__init__.py +41 -0
- soar_sdk/auth/client.py +540 -0
- soar_sdk/auth/factories.py +120 -0
- soar_sdk/auth/flows.py +172 -0
- soar_sdk/auth/httpx_auth.py +97 -0
- soar_sdk/auth/models.py +101 -0
- soar_sdk/cli/package/cli.py +6 -4
- soar_sdk/shims/phantom/base_connector.py +7 -0
- {splunk_soar_sdk-3.6.1.dist-info → splunk_soar_sdk-3.8.0.dist-info}/METADATA +3 -1
- {splunk_soar_sdk-3.6.1.dist-info → splunk_soar_sdk-3.8.0.dist-info}/RECORD +17 -11
- {splunk_soar_sdk-3.6.1.dist-info → splunk_soar_sdk-3.8.0.dist-info}/WHEEL +0 -0
- {splunk_soar_sdk-3.6.1.dist-info → splunk_soar_sdk-3.8.0.dist-info}/entry_points.txt +0 -0
- {splunk_soar_sdk-3.6.1.dist-info → splunk_soar_sdk-3.8.0.dist-info}/licenses/LICENSE +0 -0
soar_sdk/auth/client.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
import urllib.parse
|
|
5
|
+
import uuid
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from authlib.oauth2.rfc7636 import ( # type: ignore[import-untyped]
|
|
10
|
+
create_s256_code_challenge,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from soar_sdk.auth.models import (
|
|
14
|
+
CertificateCredentials,
|
|
15
|
+
OAuthConfig,
|
|
16
|
+
OAuthGrantType,
|
|
17
|
+
OAuthSession,
|
|
18
|
+
OAuthState,
|
|
19
|
+
OAuthToken,
|
|
20
|
+
)
|
|
21
|
+
from soar_sdk.logging import getLogger
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from soar_sdk.asset_state import AssetState
|
|
25
|
+
|
|
26
|
+
# Module-level logger obtained from the SDK's logging helper.
logger = getLogger()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class OAuthClientError(Exception):
    """Common base class for the OAuth client errors defined in this module."""

    pass
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TokenExpiredError(OAuthClientError):
    """The stored access token has expired and is not being refreshed.

    Raised when auto-refresh is disabled or when no refresh token is stored.
    """

    pass
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AuthorizationRequiredError(OAuthClientError):
    """No OAuth token is stored yet, so the user must authorize first."""

    pass
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TokenRefreshError(OAuthClientError):
    """Exchanging a refresh token for a new access token did not succeed."""

    pass
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ConfigurationChangedError(OAuthClientError):
    """The configured client_id no longer matches the one the stored tokens belong to.

    The stale tokens are discarded, so the user has to authorize again.
    """

    pass
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SOARAssetOAuthClient:
    """OAuth 2.0 client that persists tokens and sessions in a SOAR asset's auth state.

    Supports the client-credentials, authorization-code (with optional PKCE)
    and refresh-token grants. Everything persisted lives under the ``"oauth"``
    key of the mapping returned by ``auth_state.get_all()``.

    When ``http_client`` is omitted, an ``httpx.Client`` is created and owned
    by this instance. Release it with :meth:`close`, or use the instance as a
    context manager. A caller-supplied ``http_client`` is never closed here.
    """

    def __init__(
        self,
        config: OAuthConfig,
        auth_state: AssetState,
        *,
        http_client: httpx.Client | None = None,
        verify_ssl: bool = True,
        timeout: float = 30.0,
    ) -> None:
        self._config = config
        self._auth_state = auth_state
        self._timeout = timeout
        self._verify_ssl = verify_ssl

        # Remember whether we created the HTTP client, so close() never tears
        # down a client the caller still owns.
        self._owns_http_client = http_client is None
        self._http_client = (
            http_client
            if http_client is not None
            else httpx.Client(verify=verify_ssl, timeout=timeout)
        )

    @property
    def config(self) -> OAuthConfig:
        """Return the OAuth configuration."""
        return self._config

    def close(self) -> None:
        """Close the underlying HTTP client if this instance created it."""
        if self._owns_http_client:
            self._http_client.close()

    def __enter__(self) -> SOARAssetOAuthClient:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def _load_state(self) -> OAuthState:
        """Load the persisted OAuth state, or an empty one if nothing usable is stored."""
        state_data = self._auth_state.get_all()
        if not state_data:
            return OAuthState()

        oauth_data = state_data.get("oauth")
        if isinstance(oauth_data, dict):
            return OAuthState.model_validate(oauth_data)
        # A missing (None) or non-mapping payload is treated as "no state yet".
        return OAuthState()

    def _save_state(self, state: OAuthState) -> None:
        """Persist ``state`` under the ``"oauth"`` key, keeping all other keys."""
        current = self._auth_state.get_all()
        current["oauth"] = state.model_dump(mode="json", exclude_none=True)
        self._auth_state.put_all(current)

    def _clear_tokens(self) -> None:
        """Remove the stored token while keeping the rest of the OAuth state."""
        state = self._load_state()
        state.token = None
        self._save_state(state)

    def _check_client_id_changed(self) -> bool:
        """Return True if tokens were stored for a different client_id than configured."""
        state = self._load_state()
        if state.client_id is None:
            return False
        return state.client_id != self._config.client_id

    def get_stored_token(self) -> OAuthToken | None:
        """Return the stored token, or None if there is none.

        Raises:
            ConfigurationChangedError: The configured client_id differs from
                the one the tokens were stored for. The stale tokens are
                cleared before raising.
        """
        if self._check_client_id_changed():
            self._clear_tokens()
            raise ConfigurationChangedError(
                "OAuth client credentials have changed. Re-authorization required."
            )

        state = self._load_state()
        return state.token

    def get_valid_token(self, auto_refresh: bool = True) -> OAuthToken:
        """Return a non-expired access token, refreshing it when allowed.

        Args:
            auto_refresh: When True, an expired token that has a refresh
                token is refreshed via :meth:`refresh_token`.

        Raises:
            AuthorizationRequiredError: No token is stored.
            TokenExpiredError: The token has expired and either
                ``auto_refresh`` is False or no refresh token is stored.
            TokenRefreshError: The refresh request failed.
        """
        token = self.get_stored_token()

        if token is None:
            raise AuthorizationRequiredError(
                "No OAuth token available. Authorization is required."
            )

        if not token.is_expired():
            return token

        if not auto_refresh:
            raise TokenExpiredError("Access token has expired.")

        if token.refresh_token is None:
            raise TokenExpiredError(
                "Access token has expired and no refresh token is available."
            )

        return self.refresh_token(token.refresh_token)

    def refresh_token(self, refresh_token: str) -> OAuthToken:
        """Exchange ``refresh_token`` for a new access token and persist it.

        If the response contains no new refresh token, the one passed in is
        kept, because many providers do not rotate refresh tokens.

        Raises:
            TokenRefreshError: The token endpoint rejected the request, or the
                request could not be sent.
        """
        data: dict[str, Any] = {
            "grant_type": OAuthGrantType.REFRESH_TOKEN.value,
            "client_id": self._config.client_id,
            "refresh_token": refresh_token,
        }
        # Public clients have no secret. Omit the field, as the other grants
        # do, rather than sending an empty client_secret.
        if self._config.client_secret:
            data["client_secret"] = self._config.client_secret

        token_data = self._post_token_request(
            data,
            status_error="Token refresh failed",
            request_error="Token refresh request failed",
            error_cls=TokenRefreshError,
        )

        if "refresh_token" not in token_data:
            token_data["refresh_token"] = refresh_token

        new_token = OAuthToken.model_validate(token_data)
        self._store_token(new_token)

        return new_token

    def _store_token(self, token: OAuthToken) -> None:
        """Persist ``token`` and record which client_id it was issued for."""
        state = self._load_state()
        state.token = token
        state.client_id = self._config.client_id
        self._save_state(state)

    def fetch_token_with_client_credentials(
        self,
        *,
        extra_params: dict[str, Any] | None = None,
    ) -> OAuthToken:
        """Fetch and persist an access token using the client-credentials grant.

        Args:
            extra_params: Additional form fields; these override the defaults.

        Raises:
            OAuthClientError: The token request failed.
        """
        data: dict[str, Any] = {
            "grant_type": OAuthGrantType.CLIENT_CREDENTIALS.value,
            "client_id": self._config.client_id,
        }

        if self._config.client_secret:
            data["client_secret"] = self._config.client_secret

        scope = self._config.get_scope_string()
        if scope:
            data["scope"] = scope

        if extra_params:
            data.update(extra_params)

        token_data = self._post_token_request(
            data,
            status_error="Client credentials token request failed",
            request_error="Client credentials token request failed",
            error_cls=OAuthClientError,
        )

        token = OAuthToken.model_validate(token_data)
        self._store_token(token)

        return token

    def create_authorization_url(
        self,
        asset_id: str,
        *,
        use_pkce: bool = True,
        extra_params: dict[str, Any] | None = None,
    ) -> tuple[str, OAuthSession]:
        """Start an authorization-code flow and return ``(url, session)``.

        A new pending session is persisted, replacing any previous one. Its
        ``state`` value (url-encoded ``asset_id`` + random ``session_id``) is
        sent to the provider and checked again in the callback. With
        ``use_pkce``, an S256 code challenge is added, and the verifier is kept
        in the session for the later code exchange.

        Raises:
            OAuthClientError: The config has no authorization_endpoint.
        """
        if not self._config.authorization_endpoint:
            raise OAuthClientError(
                "authorization_endpoint is required for authorization code flow"
            )

        session_id = str(uuid.uuid4())
        state_value = urllib.parse.urlencode(
            {
                "asset_id": asset_id,
                "session_id": session_id,
            }
        )

        session = OAuthSession(
            session_id=session_id,
            asset_id=asset_id,
            state=state_value,
            auth_pending=True,
            auth_complete=False,
        )

        params: dict[str, Any] = {
            "response_type": "code",
            "client_id": self._config.client_id,
            "state": state_value,
        }

        if self._config.redirect_uri:
            params["redirect_uri"] = self._config.redirect_uri

        scope = self._config.get_scope_string()
        if scope:
            params["scope"] = scope

        if use_pkce:
            # 32 random bytes -> 43 url-safe chars, within RFC 7636's 43..128.
            code_verifier = secrets.token_urlsafe(32)
            code_challenge = create_s256_code_challenge(code_verifier)
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
            session.code_verifier = code_verifier

        if extra_params:
            params.update(extra_params)

        auth_url = (
            f"{self._config.authorization_endpoint}?{urllib.parse.urlencode(params)}"
        )

        state = self._load_state()
        state.session = session
        self._save_state(state)

        return auth_url, session

    def fetch_token_with_authorization_code(
        self,
        code: str,
        *,
        code_verifier: str | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> OAuthToken:
        """Exchange an authorization code for a token, persist it, and end the session.

        Args:
            code: The authorization code returned by the provider.
            code_verifier: PKCE verifier. Defaults to the one stored in the
                current session, if any.
            extra_params: Additional form fields; these override the defaults.

        Raises:
            OAuthClientError: The exchange request failed.
        """
        session = self._load_state().session

        if code_verifier is None and session is not None and session.code_verifier:
            code_verifier = session.code_verifier

        data: dict[str, Any] = {
            "grant_type": OAuthGrantType.AUTHORIZATION_CODE.value,
            "client_id": self._config.client_id,
            "code": code,
        }

        if self._config.client_secret:
            data["client_secret"] = self._config.client_secret

        if self._config.redirect_uri:
            data["redirect_uri"] = self._config.redirect_uri

        if code_verifier:
            data["code_verifier"] = code_verifier

        if extra_params:
            data.update(extra_params)

        token_data = self._post_token_request(
            data,
            status_error="Authorization code exchange failed",
            request_error="Authorization code exchange request failed",
            error_cls=OAuthClientError,
        )

        token = OAuthToken.model_validate(token_data)
        # Each call below re-reads the persisted state. Saving a snapshot taken
        # before _store_token() would overwrite the new token with the old one.
        self._store_token(token)
        self.clear_session()

        return token

    def handle_authorization_callback(
        self,
        callback_params: dict[str, str],
    ) -> OAuthToken:
        """Validate the provider's redirect parameters and exchange the code for a token.

        Args:
            callback_params: Query parameters of the authorization redirect.

        Raises:
            OAuthClientError: The provider reported an error, the code is
                missing, the ``state`` does not match the stored session, or
                the token exchange failed.
        """
        if "error" in callback_params:
            error = callback_params.get("error", "unknown_error")
            error_description = callback_params.get(
                "error_description", "No description provided"
            )
            raise OAuthClientError(
                f"Authorization failed: {error} - {error_description}"
            )

        code = callback_params.get("code")
        if not code:
            raise OAuthClientError("No authorization code in callback")

        # CSRF protection (RFC 6749 §10.12). While a session with a state is
        # stored, the callback must echo that exact state. A missing state is
        # treated as a mismatch rather than silently accepted.
        session = self._load_state().session
        if (
            session is not None
            and session.state
            and callback_params.get("state") != session.state
        ):
            raise OAuthClientError("State mismatch in authorization callback")

        return self.fetch_token_with_authorization_code(code)

    def get_pending_session(self) -> OAuthSession | None:
        """Return the stored session if it is still awaiting authorization."""
        state = self._load_state()
        if state.session and state.session.auth_pending:
            return state.session
        return None

    def complete_session(
        self,
        session_id: str,
        *,
        auth_code: str | None = None,
        error: str | None = None,
        error_description: str | None = None,
    ) -> None:
        """Record the outcome of the session ``session_id``.

        The session is marked complete only when ``error`` is None. Does
        nothing if the stored session has a different id, or if none exists.
        """
        state = self._load_state()
        if state.session is None or state.session.session_id != session_id:
            return

        state.session.auth_pending = False
        state.session.auth_complete = error is None
        state.session.auth_code = auth_code
        state.session.error = error
        state.session.error_description = error_description

        self._save_state(state)

    def set_authorization_code(self, code: str) -> None:
        """Store ``code`` in the current session and mark it complete (no-op without a session)."""
        state = self._load_state()
        if state.session:
            state.session.auth_code = code
            state.session.auth_pending = False
            state.session.auth_complete = True
            self._save_state(state)

    def get_authorization_code(self, *, force_reload: bool = False) -> str | None:
        """Return the code of a completed session, or None.

        Args:
            force_reload: Ask the auth state to re-read its backing storage
                first (useful when another process, e.g. the webhook, wrote it).
        """
        if force_reload:
            self._auth_state.get_all(force_reload=True)
        state = self._load_state()
        if state.session and state.session.auth_complete:
            return state.session.auth_code
        return None

    def clear_session(self) -> None:
        """Remove the stored authorization session."""
        state = self._load_state()
        state.session = None
        self._save_state(state)

    def _post_token_request(
        self,
        data: dict[str, Any],
        *,
        status_error: str,
        request_error: str,
        error_cls: type[OAuthClientError],
    ) -> Any:  # noqa: ANN401
        """POST ``data`` to the token endpoint and return the decoded JSON body.

        HTTP error statuses are raised as ``error_cls(f"{status_error}: <detail>")``.
        Transport failures are raised as ``error_cls(f"{request_error}: <exc>")``.
        Both are chained to the original httpx exception.
        """
        try:
            response = self._http_client.post(
                self._config.token_endpoint,
                data=data,
                timeout=self._timeout,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            error_detail = self._extract_error_detail(e.response)
            raise error_cls(f"{status_error}: {error_detail}") from e
        except httpx.RequestError as e:
            raise error_cls(f"{request_error}: {e}") from e

        return response.json()

    @staticmethod
    def _extract_error_detail(response: httpx.Response) -> str:
        """Return the most useful error text from a failed token-endpoint response.

        Prefers the OAuth ``error_description``, then ``error``, then the whole
        JSON body. A non-JSON body falls back to the raw text, or to
        ``"HTTP <status>"`` when the body is empty.
        """
        try:
            error_body = response.json()
        except ValueError:
            # json decode errors (including UnicodeDecodeError) are ValueErrors.
            return response.text or f"HTTP {response.status_code}"

        if isinstance(error_body, dict):
            if "error_description" in error_body:
                return str(error_body["error_description"])
            if "error" in error_body:
                return str(error_body["error"])
        return str(error_body)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
class CertificateOAuthClient(SOARAssetOAuthClient):
    """OAuth client that authenticates with a signed JWT client assertion.

    Instead of a client secret, :meth:`fetch_token_with_certificate` signs an
    RS256 JWT with the certificate's private key. It sends the JWT as a
    ``client_assertion`` (jwt-bearer assertion type) in a client-credentials
    grant.
    """

    def __init__(
        self,
        config: OAuthConfig,
        auth_state: AssetState,
        certificate: CertificateCredentials,
        *,
        http_client: httpx.Client | None = None,
        verify_ssl: bool = True,
        timeout: float = 30.0,
    ) -> None:
        super().__init__(
            config,
            auth_state,
            http_client=http_client,
            verify_ssl=verify_ssl,
            timeout=timeout,
        )
        self._certificate = certificate

    def fetch_token_with_certificate(
        self,
        *,
        extra_params: dict[str, Any] | None = None,
    ) -> OAuthToken:
        """Fetch and persist a token using a certificate-signed client assertion.

        The assertion is valid for 300 seconds from issue and carries a random
        ``jti``. The certificate's ``certificate_thumbprint`` is sent verbatim
        as the ``x5t`` header. NOTE(review): providers such as Entra ID expect
        a base64url SHA-1 thumbprint there; confirm the format that
        CertificateCredentials stores.

        Args:
            extra_params: Additional form fields; these override the defaults.

        Raises:
            OAuthClientError: The token request failed.
        """
        # Local imports, as in the original: jwt is only needed for this grant.
        import time

        import jwt

        issued_at = int(time.time())
        client_id = self._config.client_id
        token_url = self._config.token_endpoint

        claims = {
            "aud": token_url,
            "iss": client_id,
            "sub": client_id,
            "exp": issued_at + 300,
            "iat": issued_at,
            "jti": str(uuid.uuid4()),
        }
        jose_header = {
            "alg": "RS256",
            "typ": "JWT",
            "x5t": self._certificate.certificate_thumbprint,
        }
        signed_assertion = jwt.encode(
            claims,
            self._certificate.private_key,
            algorithm="RS256",
            headers=jose_header,
        )

        form: dict[str, Any] = {
            "grant_type": OAuthGrantType.CLIENT_CREDENTIALS.value,
            "client_id": client_id,
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": signed_assertion,
        }

        requested_scope = self._config.get_scope_string()
        if requested_scope:
            form["scope"] = requested_scope

        form.update(extra_params or {})

        try:
            reply = self._http_client.post(
                token_url,
                data=form,
                timeout=self._timeout,
            )
            reply.raise_for_status()
            payload = reply.json()
        except httpx.HTTPStatusError as exc:
            detail = self._extract_error_detail(exc.response)
            raise OAuthClientError(
                f"Certificate-based token request failed: {detail}"
            ) from exc
        except httpx.RequestError as exc:
            raise OAuthClientError(
                f"Certificate-based token request failed: {exc}"
            ) from exc

        fetched = OAuthToken.model_validate(payload)
        self._store_token(fetched)
        return fetched
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
from soar_sdk.auth.client import SOARAssetOAuthClient
|
|
10
|
+
from soar_sdk.auth.httpx_auth import OAuthBearerAuth
|
|
11
|
+
from soar_sdk.auth.models import OAuthConfig
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Iterator
|
|
15
|
+
|
|
16
|
+
from soar_sdk.asset import BaseAsset
|
|
17
|
+
from soar_sdk.webhooks.models import WebhookRequest, WebhookResponse
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def create_oauth_auth(
    asset: BaseAsset,
    *,
    client_id: str | None = None,
    client_secret: str | None = None,
    token_endpoint: str | None = None,
    scope: list[str] | None = None,
    auto_refresh: bool = True,
    verify_ssl: bool = True,
    timeout: float = 30.0,
    http_client: httpx.Client | None = None,
) -> OAuthBearerAuth:
    """Build an ``OAuthBearerAuth`` for *asset*, filling gaps from asset attributes.

    Explicit keyword arguments win. Otherwise values are read from the asset:
    ``client_id``, ``client_secret``, ``token_endpoint`` (falling back to
    ``token_url``) and ``scope`` (a whitespace-separated string or a list).

    Args:
        asset: Asset providing defaults and the ``auth_state`` used for storage.
        client_id: OAuth client id override.
        client_secret: OAuth client secret override.
        token_endpoint: Token endpoint URL override.
        scope: Scope list override.
        auto_refresh: Passed through to ``OAuthBearerAuth``.
        verify_ssl: TLS verification for token-endpoint requests. Ignored when
            ``http_client`` is given.
        timeout: Timeout in seconds for token-endpoint requests.
        http_client: Client used for token-endpoint requests. The caller keeps
            ownership and must close it.

    Raises:
        ValueError: No client_id or token endpoint could be resolved.
    """
    resolved_client_id = client_id or getattr(asset, "client_id", None)
    if not resolved_client_id:
        msg = "client_id must be provided or available as asset.client_id"
        raise ValueError(msg)

    resolved_token_endpoint = (
        token_endpoint
        or getattr(asset, "token_endpoint", None)
        or getattr(asset, "token_url", None)
    )
    if not resolved_token_endpoint:
        msg = "token_endpoint must be provided or available as asset.token_endpoint or asset.token_url"
        raise ValueError(msg)

    resolved_scope = scope
    if resolved_scope is None:
        asset_scope = getattr(asset, "scope", None)
        if isinstance(asset_scope, str) and asset_scope:
            resolved_scope = asset_scope.split()
        elif isinstance(asset_scope, list):
            resolved_scope = asset_scope

    config = OAuthConfig(
        client_id=resolved_client_id,
        client_secret=client_secret or getattr(asset, "client_secret", None),
        token_endpoint=resolved_token_endpoint,
        scope=resolved_scope,
    )
    oauth_client = SOARAssetOAuthClient(
        config,
        asset.auth_state,
        http_client=http_client,
        verify_ssl=verify_ssl,
        timeout=timeout,
    )
    return OAuthBearerAuth(oauth_client, auto_refresh=auto_refresh)


@contextmanager
def create_oauth_client(
    asset: BaseAsset,
    *,
    client_id: str | None = None,
    client_secret: str | None = None,
    token_endpoint: str | None = None,
    scope: list[str] | None = None,
    auto_refresh: bool = True,
    timeout: float = 30.0,
    **httpx_kwargs: Any,  # noqa: ANN401
) -> Iterator[httpx.Client]:
    """Yield an ``httpx.Client`` that authenticates with OAuth tokens for *asset*.

    Credentials are resolved as in :func:`create_oauth_auth`. Token-endpoint
    requests (initial fetch and refresh) use a dedicated client. That client
    inherits ``timeout`` and any connection settings present in
    ``httpx_kwargs`` (``verify``, ``cert``, ``proxy``, ``proxies``,
    ``trust_env``). Both clients are closed when the context exits.

    Raises:
        ValueError: No client_id or token endpoint could be resolved.
    """
    # Forward connection-level settings so token refreshes reach the endpoint
    # the same way the main client does (e.g. verify=False for self-signed certs).
    token_client_kwargs = {
        key: httpx_kwargs[key]
        for key in ("verify", "cert", "proxy", "proxies", "trust_env")
        if key in httpx_kwargs
    }
    with httpx.Client(timeout=timeout, **token_client_kwargs) as token_client:
        auth = create_oauth_auth(
            asset,
            client_id=client_id,
            client_secret=client_secret,
            token_endpoint=token_endpoint,
            scope=scope,
            auto_refresh=auto_refresh,
            timeout=timeout,
            http_client=token_client,
        )
        with httpx.Client(auth=auth, timeout=timeout, **httpx_kwargs) as client:
            yield client
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def create_oauth_callback_handler(
    get_oauth_client: Callable[[Any], SOARAssetOAuthClient],
    *,
    success_message: str = "Authorization successful! You can close this window.",
) -> Callable[[WebhookRequest], WebhookResponse]:
    """Build a webhook handler for the provider's OAuth authorization redirect.

    The handler stores the received ``code`` in the asset's pending session via
    ``set_authorization_code``. It does so only when the redirect's ``state``
    matches that session (CSRF protection, RFC 6749 §10.12). Otherwise it
    returns a 400 text response: on a provider error, a missing code, or a
    missing/mismatched state.

    Args:
        get_oauth_client: Maps ``request.asset`` to the client whose state
            holds the pending session.
        success_message: Body of the 200 response shown to the user.
    """
    from soar_sdk.webhooks.models import WebhookResponse

    def oauth_callback(request: WebhookRequest) -> WebhookResponse:
        # Collapse multi-valued query parameters to their first value.
        query_params = {k: v[0] if v else "" for k, v in request.query.items()}

        if "error" in query_params:
            reason = query_params.get("error_description", "Unknown error")
            return WebhookResponse.text_response(
                content=f"Authorization failed: {reason}",
                status_code=400,
            )

        code = query_params.get("code")
        if not code:
            return WebhookResponse.text_response(
                content="Missing authorization code",
                status_code=400,
            )

        oauth_client = get_oauth_client(request.asset)

        # Only accept a code for the session we started. Otherwise an attacker
        # could inject their own authorization code (login CSRF).
        session = oauth_client.get_pending_session()
        if session is None or query_params.get("state") != session.state:
            return WebhookResponse.text_response(
                content="Invalid or expired authorization state",
                status_code=400,
            )

        oauth_client.set_authorization_code(code)

        return WebhookResponse.text_response(
            content=success_message,
            status_code=200,
        )

    return oauth_callback
|