splunk-soar-sdk 3.6.1__py3-none-any.whl → 3.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,540 @@
1
+ from __future__ import annotations
2
+
3
+ import secrets
4
+ import urllib.parse
5
+ import uuid
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import httpx
9
+ from authlib.oauth2.rfc7636 import ( # type: ignore[import-untyped]
10
+ create_s256_code_challenge,
11
+ )
12
+
13
+ from soar_sdk.auth.models import (
14
+ CertificateCredentials,
15
+ OAuthConfig,
16
+ OAuthGrantType,
17
+ OAuthSession,
18
+ OAuthState,
19
+ OAuthToken,
20
+ )
21
+ from soar_sdk.logging import getLogger
22
+
23
+ if TYPE_CHECKING:
24
+ from soar_sdk.asset_state import AssetState
25
+
26
+ logger = getLogger()
27
+
28
+
29
class OAuthClientError(Exception):
    """Root of the OAuth client exception hierarchy.

    The more specific OAuth exceptions in this module derive from it, and the
    clients raise it directly for failed token requests and failed
    authorization callbacks.
    """
31
+
32
+
33
class TokenExpiredError(OAuthClientError):
    """Signal that the stored access token is expired and was not refreshed.

    Raised when automatic refresh is disabled or when no refresh token is
    stored alongside the expired access token.
    """
35
+
36
+
37
class AuthorizationRequiredError(OAuthClientError):
    """Signal that no token is stored, so authorization must happen first."""
39
+
40
+
41
class TokenRefreshError(OAuthClientError):
    """Signal that exchanging a refresh token for a new access token failed."""
43
+
44
+
45
class ConfigurationChangedError(OAuthClientError):
    """Signal that the configured client_id no longer matches the stored one.

    The stored token belongs to a different client, so re-authorization is
    required before a token can be used again.
    """
47
+
48
+
49
class SOARAssetOAuthClient:
    """OAuth 2.0 client that keeps its tokens and sessions in SOAR asset state.

    Supports the refresh-token, client-credentials and authorization-code
    (optionally with PKCE) grants. Everything it persists is stored under the
    ``"oauth"`` key of the asset state object it is given.

    The instance can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(
        self,
        config: OAuthConfig,
        auth_state: AssetState,
        *,
        http_client: httpx.Client | None = None,
        verify_ssl: bool = True,
        timeout: float = 30.0,
    ) -> None:
        """Bind the client to an OAuth configuration and an asset state store.

        Args:
            config: Endpoints, client credentials and scopes to use.
            auth_state: Asset state store used to load and save OAuth state.
            http_client: Optional pre-built HTTP client. When omitted, a new
                ``httpx.Client`` is created from ``verify_ssl`` and ``timeout``.
            verify_ssl: TLS verification setting for the internally created
                HTTP client.
            timeout: Timeout in seconds, applied to the internally created
                HTTP client and passed to every token request.
        """
        self._config = config
        self._auth_state = auth_state
        self._timeout = timeout
        self._verify_ssl = verify_ssl

        self._http_client = http_client or httpx.Client(
            verify=verify_ssl,
            timeout=timeout,
        )
        # Only a client created here may be closed by close(); a client handed
        # in by the caller stays under the caller's control.
        self._owns_http_client = self._http_client is not http_client

    def __enter__(self) -> SOARAssetOAuthClient:
        """Return this client for use in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Release the HTTP client when the ``with`` block ends."""
        self.close()

    def close(self) -> None:
        """Close the underlying HTTP client if this instance created it."""
        if self._owns_http_client:
            self._http_client.close()

    @property
    def config(self) -> OAuthConfig:
        """Return the OAuth configuration."""
        return self._config

    def _load_state(self) -> OAuthState:
        """Load OAuth state from asset storage.

        Returns:
            The state validated from the ``"oauth"`` entry, or an empty
            ``OAuthState`` when the store is empty, has no such entry, or the
            entry is not a dict.
        """
        state_data = self._auth_state.get_all()
        oauth_data = state_data.get("oauth") if state_data else None
        if isinstance(oauth_data, dict):
            return OAuthState.model_validate(oauth_data)
        return OAuthState()

    def _save_state(self, state: OAuthState) -> None:
        """Save OAuth state to asset storage under the ``"oauth"`` key."""
        # Mirror the empty-store guard in _load_state so a store that yields
        # nothing does not fail on item assignment.
        current = self._auth_state.get_all() or {}
        current["oauth"] = state.model_dump(mode="json", exclude_none=True)
        self._auth_state.put_all(current)

    def _clear_tokens(self) -> None:
        """Clear the stored token while keeping the rest of the OAuth state."""
        state = self._load_state()
        state.token = None
        self._save_state(state)

    def _check_client_id_changed(self) -> bool:
        """Check if the client_id has changed since tokens were stored.

        Returns:
            ``False`` when no client_id was recorded; otherwise whether the
            recorded client_id differs from the configured one.
        """
        state = self._load_state()
        if state.client_id is None:
            return False
        return state.client_id != self._config.client_id

    def get_stored_token(self) -> OAuthToken | None:
        """Retrieve the stored OAuth token if available.

        Raises:
            ConfigurationChangedError: If the configured client_id differs
                from the stored one. The stored token is cleared first.
        """
        if self._check_client_id_changed():
            self._clear_tokens()
            raise ConfigurationChangedError(
                "OAuth client credentials have changed. Re-authorization required."
            )

        state = self._load_state()
        return state.token

    def get_valid_token(self, auto_refresh: bool = True) -> OAuthToken:
        """Get a valid access token, refreshing if necessary.

        Args:
            auto_refresh: Whether an expired token may be refreshed.

        Returns:
            The stored token when it is not expired, otherwise the token
            obtained by refreshing it.

        Raises:
            AuthorizationRequiredError: If no token is stored.
            TokenExpiredError: If the token is expired and either
                ``auto_refresh`` is false or no refresh token is available.
            ConfigurationChangedError: Propagated from get_stored_token.
            TokenRefreshError: Propagated from refresh_token.
        """
        token = self.get_stored_token()

        if token is None:
            raise AuthorizationRequiredError(
                "No OAuth token available. Authorization is required."
            )

        if not token.is_expired():
            return token

        if not auto_refresh:
            raise TokenExpiredError("Access token has expired.")

        if token.refresh_token is None:
            raise TokenExpiredError(
                "Access token has expired and no refresh token is available."
            )

        return self.refresh_token(token.refresh_token)

    def refresh_token(self, refresh_token: str) -> OAuthToken:
        """Refresh an access token using a refresh token.

        The new token is stored before it is returned. When the response does
        not carry a refresh token, the one passed in is kept.

        Raises:
            TokenRefreshError: If the token endpoint answers with an error
                status or the request itself fails.
        """
        data: dict[str, Any] = {
            "grant_type": OAuthGrantType.REFRESH_TOKEN.value,
            "client_id": self._config.client_id,
            "refresh_token": refresh_token,
        }

        # Like the other grant methods, only send a secret that is configured.
        if self._config.client_secret:
            data["client_secret"] = self._config.client_secret

        try:
            response = self._http_client.post(
                self._config.token_endpoint,
                data=data,
                timeout=self._timeout,
            )
            response.raise_for_status()
            token_data = response.json()

        except httpx.HTTPStatusError as e:
            error_detail = self._extract_error_detail(e.response)
            raise TokenRefreshError(f"Token refresh failed: {error_detail}") from e

        except httpx.RequestError as e:
            raise TokenRefreshError(f"Token refresh request failed: {e}") from e

        if "refresh_token" not in token_data:
            token_data["refresh_token"] = refresh_token

        new_token = OAuthToken.model_validate(token_data)
        self._store_token(new_token)

        return new_token

    def _store_token(self, token: OAuthToken) -> None:
        """Store a token, together with the current client_id, in the auth state."""
        state = self._load_state()
        state.token = token
        state.client_id = self._config.client_id
        self._save_state(state)

    def fetch_token_with_client_credentials(
        self,
        *,
        extra_params: dict[str, Any] | None = None,
    ) -> OAuthToken:
        """Fetch an access token using client credentials grant.

        Args:
            extra_params: Additional form fields; they override the fields
                built here when the keys collide.

        Returns:
            The token from the response, which is also stored.

        Raises:
            OAuthClientError: If the token endpoint answers with an error
                status or the request itself fails.
        """
        data: dict[str, Any] = {
            "grant_type": OAuthGrantType.CLIENT_CREDENTIALS.value,
            "client_id": self._config.client_id,
        }

        if self._config.client_secret:
            data["client_secret"] = self._config.client_secret

        scope = self._config.get_scope_string()
        if scope:
            data["scope"] = scope

        if extra_params:
            data.update(extra_params)

        try:
            response = self._http_client.post(
                self._config.token_endpoint,
                data=data,
                timeout=self._timeout,
            )
            response.raise_for_status()
            token_data = response.json()

        except httpx.HTTPStatusError as e:
            error_detail = self._extract_error_detail(e.response)
            raise OAuthClientError(
                f"Client credentials token request failed: {error_detail}"
            ) from e

        except httpx.RequestError as e:
            raise OAuthClientError(
                f"Client credentials token request failed: {e}"
            ) from e

        token = OAuthToken.model_validate(token_data)
        self._store_token(token)

        return token

    def create_authorization_url(
        self,
        asset_id: str,
        *,
        use_pkce: bool = True,
        extra_params: dict[str, Any] | None = None,
    ) -> tuple[str, OAuthSession]:
        """Create a sessioned authorization URL for the authorization code flow.

        Args:
            asset_id: Asset identifier embedded in the ``state`` parameter.
            use_pkce: Whether to add an S256 PKCE challenge and remember the
                verifier on the session.
            extra_params: Additional query parameters; they override the
                parameters built here when the keys collide.

        Returns:
            The authorization URL and the new session, which is also saved to
            the auth state.

        Raises:
            OAuthClientError: If no authorization endpoint is configured.
        """
        if not self._config.authorization_endpoint:
            raise OAuthClientError(
                "authorization_endpoint is required for authorization code flow"
            )

        session_id = str(uuid.uuid4())
        state_value = urllib.parse.urlencode(
            {
                "asset_id": asset_id,
                "session_id": session_id,
            }
        )

        session = OAuthSession(
            session_id=session_id,
            asset_id=asset_id,
            state=state_value,
            auth_pending=True,
            auth_complete=False,
        )

        params: dict[str, Any] = {
            "response_type": "code",
            "client_id": self._config.client_id,
            "state": state_value,
        }

        if self._config.redirect_uri:
            params["redirect_uri"] = self._config.redirect_uri

        scope = self._config.get_scope_string()
        if scope:
            params["scope"] = scope

        if use_pkce:
            code_verifier = secrets.token_urlsafe(32)
            code_challenge = create_s256_code_challenge(code_verifier)
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
            session.code_verifier = code_verifier

        if extra_params:
            params.update(extra_params)

        # An endpoint that already carries a query string must be extended
        # with "&"; a second "?" would corrupt the URL.
        endpoint = str(self._config.authorization_endpoint)
        separator = "&" if "?" in endpoint else "?"
        auth_url = f"{endpoint}{separator}{urllib.parse.urlencode(params)}"

        state = self._load_state()
        state.session = session
        self._save_state(state)

        return auth_url, session

    def fetch_token_with_authorization_code(
        self,
        code: str,
        *,
        code_verifier: str | None = None,
        extra_params: dict[str, Any] | None = None,
    ) -> OAuthToken:
        """Exchange an authorization code for an access token.

        Args:
            code: Authorization code received from the authorization server.
            code_verifier: PKCE verifier; defaults to the one remembered on
                the stored session, if any.
            extra_params: Additional form fields; they override the fields
                built here when the keys collide.

        Returns:
            The token from the response. It is stored and the authorization
            session is cleared.

        Raises:
            OAuthClientError: If the token endpoint answers with an error
                status or the request itself fails.
        """
        session = self._load_state().session

        if code_verifier is None and session and session.code_verifier:
            code_verifier = session.code_verifier

        data: dict[str, Any] = {
            "grant_type": OAuthGrantType.AUTHORIZATION_CODE.value,
            "client_id": self._config.client_id,
            "code": code,
        }

        if self._config.client_secret:
            data["client_secret"] = self._config.client_secret

        if self._config.redirect_uri:
            data["redirect_uri"] = self._config.redirect_uri

        if code_verifier:
            data["code_verifier"] = code_verifier

        if extra_params:
            data.update(extra_params)

        try:
            response = self._http_client.post(
                self._config.token_endpoint,
                data=data,
                timeout=self._timeout,
            )
            response.raise_for_status()
            token_data = response.json()

        except httpx.HTTPStatusError as e:
            error_detail = self._extract_error_detail(e.response)
            raise OAuthClientError(
                f"Authorization code exchange failed: {error_detail}"
            ) from e

        except httpx.RequestError as e:
            raise OAuthClientError(
                f"Authorization code exchange request failed: {e}"
            ) from e

        token = OAuthToken.model_validate(token_data)
        self._store_token(token)

        # Clear the session from a freshly loaded state. Re-saving the state
        # object loaded before the request would overwrite the token and
        # client_id that _store_token has just persisted.
        self.clear_session()

        return token

    def handle_authorization_callback(
        self,
        callback_params: dict[str, str],
    ) -> OAuthToken:
        """Handle the OAuth authorization callback.

        Args:
            callback_params: Query parameters received on the redirect.

        Returns:
            The token obtained by exchanging the callback's code.

        Raises:
            OAuthClientError: If the callback reports an error, carries no
                code, or its ``state`` does not match the stored session.
        """
        if "error" in callback_params:
            error = callback_params.get("error", "unknown_error")
            error_description = callback_params.get(
                "error_description", "No description provided"
            )
            raise OAuthClientError(
                f"Authorization failed: {error} - {error_description}"
            )

        code = callback_params.get("code")
        if not code:
            raise OAuthClientError("No authorization code in callback")

        session = self._load_state().session
        if session is not None:
            expected_state = session.state
            received_state = callback_params.get("state")
            if expected_state:
                # A session that sent a state value must get it back; a
                # callback without it is rejected rather than waved through.
                # Compare as bytes so non-ASCII input cannot raise.
                state_ok = received_state is not None and secrets.compare_digest(
                    received_state.encode("utf-8"),
                    expected_state.encode("utf-8"),
                )
            else:
                state_ok = not received_state
            if not state_ok:
                raise OAuthClientError("State mismatch in authorization callback")

        return self.fetch_token_with_authorization_code(code)

    def get_pending_session(self) -> OAuthSession | None:
        """Get the current pending authorization session if any."""
        state = self._load_state()
        if state.session and state.session.auth_pending:
            return state.session
        return None

    def complete_session(
        self,
        session_id: str,
        *,
        auth_code: str | None = None,
        error: str | None = None,
        error_description: str | None = None,
    ) -> None:
        """Mark an authorization session as complete.

        Does nothing when no session is stored or the stored session has a
        different ``session_id``. The session counts as successfully complete
        only when ``error`` is ``None``.
        """
        state = self._load_state()
        if state.session is None or state.session.session_id != session_id:
            return

        state.session.auth_pending = False
        state.session.auth_complete = error is None
        state.session.auth_code = auth_code
        state.session.error = error
        state.session.error_description = error_description

        self._save_state(state)

    def set_authorization_code(self, code: str) -> None:
        """Store an authorization code in the current session.

        Does nothing when no session is stored.
        """
        state = self._load_state()
        if state.session:
            state.session.auth_code = code
            state.session.auth_pending = False
            state.session.auth_complete = True
            self._save_state(state)

    def get_authorization_code(self, *, force_reload: bool = False) -> str | None:
        """Retrieve the authorization code from the current session.

        Args:
            force_reload: When true, ask the asset state store to reload
                before reading.

        Returns:
            The session's code when the session is marked complete, else
            ``None``.
        """
        if force_reload:
            self._auth_state.get_all(force_reload=True)
        state = self._load_state()
        if state.session and state.session.auth_complete:
            return state.session.auth_code
        return None

    def clear_session(self) -> None:
        """Clear the current authorization session."""
        state = self._load_state()
        state.session = None
        self._save_state(state)

    @staticmethod
    def _extract_error_detail(response: httpx.Response) -> str:
        """Extract error details from an HTTP response.

        Returns:
            For a JSON object body, ``error_description`` if present, else
            ``error``; for any other JSON body, its string form; when the body
            cannot be parsed, the response text or ``"HTTP <status>"``.
        """
        try:
            error_body = response.json()
            if isinstance(error_body, dict):
                if "error_description" in error_body:
                    return str(error_body["error_description"])
                if "error" in error_body:
                    return str(error_body["error"])
            return str(error_body)
        except Exception:
            # Best effort: this only builds an error message, so any parsing
            # problem falls back to the raw text instead of raising.
            return response.text or f"HTTP {response.status_code}"
445
+
446
+
447
class CertificateOAuthClient(SOARAssetOAuthClient):
    """OAuth client that authenticates with a certificate-signed JWT assertion."""

    def __init__(
        self,
        config: OAuthConfig,
        auth_state: AssetState,
        certificate: CertificateCredentials,
        *,
        http_client: httpx.Client | None = None,
        verify_ssl: bool = True,
        timeout: float = 30.0,
    ) -> None:
        """Set up the base client and remember the certificate credentials.

        Args:
            config: OAuth configuration forwarded to the base client.
            auth_state: Asset state store forwarded to the base client.
            certificate: Source of the private key and certificate thumbprint
                used to sign client assertions.
            http_client: Optional HTTP client forwarded to the base client.
            verify_ssl: TLS verification flag forwarded to the base client.
            timeout: Timeout in seconds forwarded to the base client.
        """
        base_options = {
            "http_client": http_client,
            "verify_ssl": verify_ssl,
            "timeout": timeout,
        }
        super().__init__(config, auth_state, **base_options)
        self._certificate = certificate

    def fetch_token_with_certificate(
        self,
        *,
        extra_params: dict[str, Any] | None = None,
    ) -> OAuthToken:
        """Request a client-credentials token, proving identity with a signed JWT.

        Args:
            extra_params: Additional form fields; they override the fields
                built here when the keys collide.

        Returns:
            The token from the response, which is also stored.

        Raises:
            OAuthClientError: If the token endpoint answers with an error
                status or the request itself fails.
        """
        import time

        import jwt

        issued_at = int(time.time())

        # The assertion names this client as issuer and subject, targets the
        # token endpoint, and expires 300 seconds after it is issued.
        claims = {
            "aud": self._config.token_endpoint,
            "iss": self._config.client_id,
            "sub": self._config.client_id,
            "exp": issued_at + 300,
            "iat": issued_at,
            "jti": str(uuid.uuid4()),
        }
        jose_headers = {
            "alg": "RS256",
            "typ": "JWT",
            "x5t": self._certificate.certificate_thumbprint,
        }
        signed_assertion = jwt.encode(
            claims,
            self._certificate.private_key,
            algorithm="RS256",
            headers=jose_headers,
        )

        form: dict[str, Any] = {
            "grant_type": OAuthGrantType.CLIENT_CREDENTIALS.value,
            "client_id": self._config.client_id,
            "client_assertion_type": (
                "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
            ),
            "client_assertion": signed_assertion,
        }

        requested_scope = self._config.get_scope_string()
        if requested_scope:
            form["scope"] = requested_scope

        if extra_params:
            form.update(extra_params)

        try:
            reply = self._http_client.post(
                self._config.token_endpoint,
                data=form,
                timeout=self._timeout,
            )
            reply.raise_for_status()
            payload = reply.json()
        except httpx.HTTPStatusError as status_error:
            detail = self._extract_error_detail(status_error.response)
            raise OAuthClientError(
                f"Certificate-based token request failed: {detail}"
            ) from status_error
        except httpx.RequestError as transport_error:
            raise OAuthClientError(
                f"Certificate-based token request failed: {transport_error}"
            ) from transport_error

        issued_token = OAuthToken.model_validate(payload)
        self._store_token(issued_token)
        return issued_token
@@ -0,0 +1,120 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from contextlib import contextmanager
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import httpx
8
+
9
+ from soar_sdk.auth.client import SOARAssetOAuthClient
10
+ from soar_sdk.auth.httpx_auth import OAuthBearerAuth
11
+ from soar_sdk.auth.models import OAuthConfig
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Iterator
15
+
16
+ from soar_sdk.asset import BaseAsset
17
+ from soar_sdk.webhooks.models import WebhookRequest, WebhookResponse
18
+
19
+
20
def create_oauth_auth(
    asset: BaseAsset,
    *,
    client_id: str | None = None,
    client_secret: str | None = None,
    token_endpoint: str | None = None,
    scope: list[str] | None = None,
    auto_refresh: bool = True,
) -> OAuthBearerAuth:
    """Build an OAuthBearerAuth, filling missing settings from the asset.

    Explicit arguments win. Anything left unset is read from the asset's
    attributes of the same name; the token endpoint also falls back to
    ``asset.token_url``.

    Args:
        asset: Asset providing fallback settings and the ``auth_state`` store.
        client_id: OAuth client identifier.
        client_secret: OAuth client secret.
        token_endpoint: Token endpoint URL.
        scope: Scopes to request. When ``None``, ``asset.scope`` is used: a
            non-empty string is split on whitespace, a list is used as is.
        auto_refresh: Forwarded to ``OAuthBearerAuth``.

    Raises:
        ValueError: If no client_id or no token endpoint can be resolved.
    """
    effective_client_id = client_id if client_id else getattr(asset, "client_id", None)
    if not effective_client_id:
        raise ValueError("client_id must be provided or available as asset.client_id")

    # Walk the asset fallbacks only until one yields a usable endpoint.
    effective_endpoint = token_endpoint
    for attribute in ("token_endpoint", "token_url"):
        if effective_endpoint:
            break
        effective_endpoint = getattr(asset, attribute, None)
    if not effective_endpoint:
        raise ValueError(
            "token_endpoint must be provided or available as asset.token_endpoint or asset.token_url"
        )

    effective_scope = scope
    if effective_scope is None:
        declared_scope = getattr(asset, "scope", None)
        if isinstance(declared_scope, list):
            effective_scope = declared_scope
        elif isinstance(declared_scope, str) and declared_scope:
            effective_scope = declared_scope.split()

    effective_secret = client_secret or getattr(asset, "client_secret", None)

    oauth_config = OAuthConfig(
        client_id=effective_client_id,
        client_secret=effective_secret,
        token_endpoint=effective_endpoint,
        scope=effective_scope,
    )
    return OAuthBearerAuth(
        SOARAssetOAuthClient(oauth_config, asset.auth_state),
        auto_refresh=auto_refresh,
    )
60
+
61
+
62
@contextmanager
def create_oauth_client(
    asset: BaseAsset,
    *,
    client_id: str | None = None,
    client_secret: str | None = None,
    token_endpoint: str | None = None,
    scope: list[str] | None = None,
    auto_refresh: bool = True,
    timeout: float = 30.0,
    **httpx_kwargs: Any,  # noqa: ANN401
) -> Iterator[httpx.Client]:
    """Yield an httpx.Client that authenticates with the asset's OAuth settings.

    The OAuth arguments are passed to ``create_oauth_auth``; ``timeout`` and
    any extra keyword arguments go to ``httpx.Client``. The HTTP client is
    closed when the context exits.
    """
    auth_options = {
        "client_id": client_id,
        "client_secret": client_secret,
        "token_endpoint": token_endpoint,
        "scope": scope,
        "auto_refresh": auto_refresh,
    }
    bearer_auth = create_oauth_auth(asset, **auth_options)

    with httpx.Client(auth=bearer_auth, timeout=timeout, **httpx_kwargs) as session:
        yield session
85
+
86
+
87
def create_oauth_callback_handler(
    get_oauth_client: Callable[[Any], SOARAssetOAuthClient],
    *,
    success_message: str = "Authorization successful! You can close this window.",
) -> Callable[[WebhookRequest], WebhookResponse]:
    """Build a webhook handler for the OAuth redirect callback.

    Args:
        get_oauth_client: Called with ``request.asset`` to obtain the OAuth
            client that should receive the authorization code.
        success_message: Body of the 200 response sent after the code has
            been handed to the client.

    Returns:
        A handler that answers 400 when the callback reports an error or has
        no code, and otherwise stores the code and answers 200.
    """
    from soar_sdk.webhooks.models import WebhookResponse

    def _plain(body: str, status: int) -> WebhookResponse:
        # Every outcome is reported as a text response.
        return WebhookResponse.text_response(content=body, status_code=status)

    def oauth_callback(request: WebhookRequest) -> WebhookResponse:
        # Each query key maps to a list of values; keep only the first one.
        first_values: dict[str, Any] = {}
        for key, values in request.query.items():
            first_values[key] = values[0] if values else ""

        if "error" in first_values:
            failure = first_values.get("error_description", "Unknown error")
            return _plain(f"Authorization failed: {failure}", 400)

        auth_code = first_values.get("code")
        if not auth_code:
            return _plain("Missing authorization code", 400)

        get_oauth_client(request.asset).set_authorization_code(auth_code)
        return _plain(success_message, 200)

    return oauth_callback