aio-sf 0.1.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,511 @@
1
+ """
2
+ Salesforce connection module providing authentication and basic API functionality.
3
+ """
4
+
5
+ import base64
6
+ import logging
7
+ import time
8
+ from abc import ABC, abstractmethod
9
+ from typing import Dict, Optional, Any
10
+ from urllib.parse import urljoin
11
+
12
+ import httpx
13
+
14
# Module-level logger named after this module. This module only emits log
# records; it does not configure handlers or levels.
logger = logging.getLogger(__name__)
15
+
16
+
17
class SalesforceAuthError(Exception):
    """Signal that Salesforce authentication failed or a token is unusable.

    Raised by the authentication strategies when a token cannot be obtained,
    refreshed, or is missing/expired, and by ``SalesforceConnection.request``
    when re-authenticating after a 401 response does not succeed.
    """
21
+
22
+
23
class AuthStrategy(ABC):
    """Abstract base class for Salesforce authentication strategies.

    A strategy stores the current access token and, when known, its expiry
    (``expires_at``, a Unix timestamp in seconds). ``SalesforceConnection``
    calls :meth:`refresh_if_needed` before authenticated requests.
    """

    def __init__(self, instance_url: Optional[str]):
        """
        :param instance_url: Base Salesforce URL. A trailing slash is stripped
            so paths can be appended directly; a falsy value is stored as None.
        """
        self.instance_url: Optional[str] = (
            instance_url.rstrip("/") if instance_url else None
        )
        self.access_token: Optional[str] = None
        # Unix timestamp (seconds) at/after which the token counts as expired.
        # None means the expiry is unknown.
        self.expires_at: Optional[int] = None

    @abstractmethod
    async def authenticate(self, http_client: httpx.AsyncClient) -> str:
        """Obtain an access token, store it on the strategy and return it."""

    @abstractmethod
    async def refresh_if_needed(self, http_client: httpx.AsyncClient) -> str:
        """Return a usable access token, refreshing it first if necessary."""

    @abstractmethod
    def can_refresh(self) -> bool:
        """Return True if this strategy can obtain a new token after expiry."""

    def is_token_expired(self) -> bool:
        """
        Return True if the token's known expiry time has been reached.

        A token with no recorded expiry (``expires_at is None``) is never
        reported as expired; ``SalesforceConnection.request`` instead detects
        invalid tokens through a 401 response.
        """
        # Compare against None explicitly so a timestamp of 0 still counts as
        # a (long past) expiry rather than "unknown".
        if self.expires_at is None:
            return False
        return self.expires_at <= int(time.time())
51
+
52
+
53
class ClientCredentialsAuth(AuthStrategy):
    """OAuth Client Credentials authentication strategy.

    Exchanges a connected app's client id and secret for an access token, then
    looks up the token's expiry through the introspection endpoint. Because the
    credentials themselves are reusable, a new token can always be requested.
    """

    def __init__(self, instance_url: str, client_id: str, client_secret: str):
        """
        :param instance_url: Base Salesforce URL hosting the OAuth endpoints
        :param client_id: Connected app client id (consumer key)
        :param client_secret: Connected app client secret (consumer secret)
        """
        super().__init__(instance_url)
        self.client_id = client_id
        self.client_secret = client_secret

    async def authenticate(self, http_client: httpx.AsyncClient) -> str:
        """
        Obtain a new access token using the OAuth client credentials flow.

        On success ``access_token`` and ``expires_at`` are updated.

        :param http_client: Client used for the token and introspection calls
        :returns: The newly issued access token
        :raises SalesforceAuthError: If ``instance_url`` is missing, or the
            token/introspection request fails or returns an unexpected payload
        """
        if not self.instance_url:
            raise SalesforceAuthError(
                "Authentication failed: instance_url is required"
            )

        logger.info("Getting Salesforce access token using client credentials")

        # Append to instance_url (as SalesforceConnection.get_base_url does)
        # rather than urljoin(): urljoin with an absolute path silently drops
        # any path prefix present on instance_url.
        oauth_url = f"{self.instance_url}/services/oauth2/token"
        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }
        oauth_headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }

        try:
            response = await http_client.post(
                oauth_url, data=data, headers=oauth_headers
            )
            response.raise_for_status()
            self.access_token = response.json()["access_token"]

            # Record when the new token expires.
            await self._get_token_expiration(http_client)
        except httpx.HTTPError as e:
            logger.error("HTTP error getting access token: %s", e)
            raise SalesforceAuthError(f"Authentication failed: {e}") from e
        except Exception as e:
            logger.error("Unexpected error getting access token: %s", e)
            raise SalesforceAuthError(f"Authentication failed: {e}") from e

        logger.info("Successfully obtained Salesforce access token")
        return self.access_token

    async def refresh_if_needed(self, http_client: httpx.AsyncClient) -> str:
        """
        Return the cached token while it is still valid; otherwise request a
        new one with the client credentials.
        """
        if self.access_token and not self.is_token_expired():
            return self.access_token
        return await self.authenticate(http_client)

    def can_refresh(self) -> bool:
        """Client credentials can always be used to obtain a new token."""
        return True

    async def _get_token_expiration(self, http_client: httpx.AsyncClient) -> None:
        """
        Set ``expires_at`` from the OAuth token introspection endpoint.

        The expiry is moved 30 seconds earlier so the token is replaced shortly
        before it actually lapses.

        :raises httpx.HTTPError: If the introspection request fails
        :raises ValueError: If the introspection response has no ``exp`` value
        """
        introspect_url = f"{self.instance_url}/services/oauth2/introspect"

        # The introspection call authenticates the app via HTTP Basic auth.
        basic_credentials = base64.b64encode(
            f"{self.client_id}:{self.client_secret}".encode("utf-8")
        ).decode("utf-8")

        response = await http_client.post(
            introspect_url,
            data={"token": self.access_token, "token_type_hint": "access_token"},
            headers={
                "Authorization": f"Basic {basic_credentials}",
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
        )
        response.raise_for_status()
        claims = response.json()

        exp = claims.get("exp")
        if exp is None:
            # RFC 7662 allows an inactive token to be reported as only
            # {"active": false}; fail with a clear message instead of KeyError.
            raise ValueError(
                "token introspection response has no 'exp' "
                f"(active={claims.get('active')!r})"
            )

        # Set expiration time with 30 second buffer
        self.expires_at = int(exp) - 30
134
+
135
+
136
class RefreshTokenAuth(AuthStrategy):
    """OAuth Refresh Token authentication strategy.

    Starts from a caller-supplied access token and uses the refresh token to
    obtain a new one whenever the current token is missing or known to be
    expired.
    """

    def __init__(
        self,
        instance_url: str,
        access_token: str,
        refresh_token: str,
        client_id: str,
        client_secret: str,
    ):
        """
        :param instance_url: Base Salesforce URL hosting the OAuth endpoints
        :param access_token: Initial access token (used until it is rejected)
        :param refresh_token: Refresh token used to obtain new access tokens
        :param client_id: Connected app client id (consumer key)
        :param client_secret: Connected app client secret (consumer secret)
        """
        super().__init__(instance_url)
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.client_id = client_id
        self.client_secret = client_secret

    async def authenticate(self, http_client: httpx.AsyncClient) -> str:
        """Use the provided access token, refreshing it first if needed."""
        return await self.refresh_if_needed(http_client)

    async def refresh_if_needed(self, http_client: httpx.AsyncClient) -> str:
        """
        Return the cached access token while it is valid; otherwise refresh it.
        """
        if self.access_token and not self.is_token_expired():
            return self.access_token
        return await self._refresh_token(http_client)

    def can_refresh(self) -> bool:
        """Return True when a refresh token is available."""
        return bool(self.refresh_token)

    async def _refresh_token(self, http_client: httpx.AsyncClient) -> str:
        """
        Exchange the refresh token for a new access token.

        On success ``access_token`` and ``expires_at`` are updated, and
        ``refresh_token`` is replaced if the response contains a new one.

        :param http_client: Client used for the token and introspection calls
        :returns: The new access token
        :raises SalesforceAuthError: If no refresh token or ``instance_url`` is
            configured, or the token/introspection request fails
        """
        if not self.refresh_token:
            # Matches can_refresh(): without a refresh token there is nothing
            # to exchange, so fail fast instead of posting an invalid request.
            raise SalesforceAuthError(
                "Token refresh failed: no refresh token available"
            )
        if not self.instance_url:
            raise SalesforceAuthError(
                "Token refresh failed: instance_url is required"
            )

        logger.info("Refreshing Salesforce access token using refresh token")

        # Append to instance_url (as SalesforceConnection.get_base_url does)
        # rather than urljoin(): urljoin with an absolute path silently drops
        # any path prefix present on instance_url.
        oauth_url = f"{self.instance_url}/services/oauth2/token"
        data = {
            "grant_type": "refresh_token",
            "refresh_token": self.refresh_token,
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }
        oauth_headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }

        try:
            response = await http_client.post(
                oauth_url, data=data, headers=oauth_headers
            )
            response.raise_for_status()
            token_data = response.json()
            self.access_token = token_data["access_token"]

            # Store any newly issued refresh token before the introspection
            # call, so it is kept even if that call fails.
            if "refresh_token" in token_data:
                self.refresh_token = token_data["refresh_token"]

            # Record when the new token expires.
            await self._get_token_expiration(http_client)
        except httpx.HTTPError as e:
            logger.error("HTTP error refreshing access token: %s", e)
            raise SalesforceAuthError(f"Token refresh failed: {e}") from e
        except Exception as e:
            logger.error("Unexpected error refreshing access token: %s", e)
            raise SalesforceAuthError(f"Token refresh failed: {e}") from e

        logger.info("Successfully refreshed Salesforce access token")
        return self.access_token

    async def _get_token_expiration(self, http_client: httpx.AsyncClient) -> None:
        """
        Set ``expires_at`` from the OAuth token introspection endpoint.

        The expiry is moved 30 seconds earlier so the token is replaced shortly
        before it actually lapses.

        :raises httpx.HTTPError: If the introspection request fails
        :raises ValueError: If the introspection response has no ``exp`` value
        """
        introspect_url = f"{self.instance_url}/services/oauth2/introspect"

        # The introspection call authenticates the app via HTTP Basic auth.
        basic_credentials = base64.b64encode(
            f"{self.client_id}:{self.client_secret}".encode("utf-8")
        ).decode("utf-8")

        response = await http_client.post(
            introspect_url,
            data={"token": self.access_token, "token_type_hint": "access_token"},
            headers={
                "Authorization": f"Basic {basic_credentials}",
                "Content-Type": "application/x-www-form-urlencoded",
                "Accept": "application/json",
            },
        )
        response.raise_for_status()
        claims = response.json()

        exp = claims.get("exp")
        if exp is None:
            # RFC 7662 allows an inactive token to be reported as only
            # {"active": false}; fail with a clear message instead of KeyError.
            raise ValueError(
                "token introspection response has no 'exp' "
                f"(active={claims.get('active')!r})"
            )

        # Set expiration time with 30 second buffer
        self.expires_at = int(exp) - 30
237
+
238
+
239
class StaticTokenAuth(AuthStrategy):
    """Authentication strategy wrapping a fixed, caller-supplied access token.

    The token is never renewed: ``authenticate`` and ``refresh_if_needed``
    raise :class:`SalesforceAuthError` once the token is missing or known to
    be expired.
    """

    def __init__(self, instance_url: str, access_token: str):
        """
        :param instance_url: Base Salesforce URL
        :param access_token: Pre-obtained OAuth access token
        """
        super().__init__(instance_url)
        self.access_token = access_token

    async def authenticate(self, http_client: httpx.AsyncClient) -> str:
        """Hand back the stored token; ``http_client`` is not used."""
        token = self.access_token
        if token:
            return token
        raise SalesforceAuthError("No access token available")

    async def refresh_if_needed(self, http_client: httpx.AsyncClient) -> str:
        """Return the stored token, or raise if it is known to have expired."""
        if not self.is_token_expired():
            return await self.authenticate(http_client)
        raise SalesforceAuthError(
            "Access token has expired and no refresh capability is available."
        )

    def can_refresh(self) -> bool:
        """Report that a static token cannot be renewed."""
        return False
263
+
264
+
265
class SalesforceConnection:
    """
    A connection object containing Salesforce authentication details and basic API functionality.

    This provides a simple interface for Salesforce API interactions using explicit
    authentication strategies. The connection lazily creates and owns an
    ``httpx.AsyncClient``; use it as an async context manager (or call
    :meth:`close`) to release that client.
    """

    def __init__(
        self,
        auth_strategy: AuthStrategy,
        version: str = "v60.0",
        timeout: float = 30.0,
    ):
        """
        Initialize Salesforce connection with an explicit authentication strategy.

        :param auth_strategy: Authentication strategy to use (ClientCredentialsAuth, RefreshTokenAuth, or StaticTokenAuth)
        :param version: API version (e.g., "v60.0")
        :param timeout: HTTP request timeout in seconds
        """
        self.auth_strategy = auth_strategy
        self.version = version
        self.timeout = timeout

        # Persistent HTTP client, created on first use by the http_client property.
        self._http_client: Optional[httpx.AsyncClient] = None

        # Network-location part of instance_url (no scheme, no path), kept for
        # compatibility with callers that read ``connection.instance``.
        url = self.auth_strategy.instance_url
        if url:
            self.instance: Optional[str] = url.split("://", 1)[-1].split("/", 1)[0]
        else:
            self.instance = None

        # API helper objects, created lazily by the describe/bulk_v2/query properties.
        self._describe_api = None
        self._bulk_v2_api = None
        self._query_api = None

    @property
    def instance_url(self) -> Optional[str]:
        """Get the instance URL from the auth strategy."""
        return self.auth_strategy.instance_url

    @property
    def access_token(self) -> Optional[str]:
        """Get the current access token from the auth strategy."""
        return self.auth_strategy.access_token

    @property
    def describe(self):
        """Access to Salesforce Describe API methods (created on first use)."""
        if self._describe_api is None:
            from .api.describe import DescribeAPI

            self._describe_api = DescribeAPI(self)
        return self._describe_api

    @property
    def bulk_v2(self):
        """Access to Salesforce Bulk API v2 methods (created on first use)."""
        if self._bulk_v2_api is None:
            from .api.bulk_v2 import BulkV2API

            self._bulk_v2_api = BulkV2API(self)
        return self._bulk_v2_api

    @property
    def query(self):
        """Access to Salesforce Query API methods (created on first use)."""
        if self._query_api is None:
            from .api.query import QueryAPI

            self._query_api = QueryAPI(self)
        return self._query_api

    @property
    def http_client(self) -> httpx.AsyncClient:
        """Get the HTTP client, creating a new one if none exists or it was closed."""
        if self._http_client is None or self._http_client.is_closed:
            self._http_client = httpx.AsyncClient(timeout=self.timeout)
        return self._http_client

    async def close(self):
        """Close the HTTP client (if open) and clean up resources."""
        if self._http_client is not None and not self._http_client.is_closed:
            await self._http_client.aclose()
            self._http_client = None

    async def __aenter__(self):
        """Async context manager entry; returns the connection itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit; closes the HTTP client."""
        await self.close()

    async def ensure_authenticated(self) -> str:
        """
        Ensure the connection has a valid access token using the configured auth strategy.

        :returns: Valid access token
        :raises: SalesforceAuthError if authentication/refresh fails
        """
        return await self.auth_strategy.refresh_if_needed(self.http_client)

    @property
    def headers(self) -> Dict[str, str]:
        """
        Get the standard JSON headers for API requests (a new dict each call).

        :raises ValueError: If no access token is available yet
        """
        if not self.access_token:
            raise ValueError(
                "No access token available. Call ensure_authenticated() first."
            )
        return {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        }

    async def get_authenticated_headers(self) -> Dict[str, str]:
        """
        Get headers with a valid access token, ensuring authentication first.

        :returns: Headers dictionary with valid Bearer token
        """
        await self.ensure_authenticated()
        return self.headers

    def get_base_url(self, api_version: Optional[str] = None) -> str:
        """
        Get the base URL for API requests.

        :param api_version: API version to use (defaults to connection version)
        :returns: Base URL for Salesforce API
        :raises ValueError: If no instance URL is configured
        """
        if not self.instance_url:
            raise ValueError("instance_url is required to build API URLs")
        effective_version = api_version or self.version
        return f"{self.instance_url}/services/data/{effective_version}"

    def get_sobject_url(
        self, sobject_type: str, api_version: Optional[str] = None
    ) -> str:
        """
        Get the URL for sobject operations.

        :param sobject_type: Salesforce object type (e.g., 'Account', 'Contact')
        :param api_version: API version to use (defaults to connection version)
        :returns: URL for sobject operations
        """
        base_url = self.get_base_url(api_version)
        return f"{base_url}/sobjects/{sobject_type}"

    def get_describe_url(
        self, sobject_type: str, api_version: Optional[str] = None
    ) -> str:
        """
        Get the URL for describing a Salesforce object.

        :param sobject_type: Salesforce object type (e.g., 'Account', 'Contact')
        :param api_version: API version to use (defaults to connection version)
        :returns: URL for describe operation
        """
        sobject_url = self.get_sobject_url(sobject_type, api_version)
        return f"{sobject_url}/describe"

    def _merge_headers(self, extra: Optional[Dict[str, str]]) -> Dict[str, str]:
        """Return the auth headers with caller-supplied ``extra`` headers layered on top."""
        # self.headers builds a fresh dict on every access, so updating it is safe.
        merged = self.headers
        if extra:
            merged.update(extra)
        return merged

    async def request(
        self,
        method: str,
        url: str,
        headers: Optional[Dict[str, str]] = None,
        auto_auth: bool = True,
        **kwargs,
    ) -> httpx.Response:
        """
        Make an authenticated HTTP request through the connection.

        When ``auto_auth`` is true and the server answers 401, the cached token
        is discarded, a new one is obtained, and the request is retried once.

        :param method: HTTP method (GET, POST, etc.)
        :param url: Full URL to request
        :param headers: Additional headers (will be merged with auth headers)
        :param auto_auth: Whether to automatically ensure authentication
        :param kwargs: Additional arguments passed to httpx request
        :returns: HTTP response
        :raises: SalesforceAuthError if authentication fails
        """
        if auto_auth:
            await self.ensure_authenticated()

        client = self.http_client
        response = await client.request(
            method, url, headers=self._merge_headers(headers), **kwargs
        )

        if response.status_code != 401 or not auto_auth:
            return response

        logger.info("Got 401 response, attempting to re-authenticate")

        # Force re-authentication by clearing the cached token in the strategy,
        # remembering the previous state in case re-authentication fails.
        old_token = self.auth_strategy.access_token
        old_expires_at = self.auth_strategy.expires_at
        self.auth_strategy.access_token = None
        self.auth_strategy.expires_at = None

        # Only the re-authentication step is treated as an auth failure; the
        # retried request below is outside this try so that its own errors are
        # not misreported and do not roll back the freshly obtained token.
        try:
            await self.ensure_authenticated()
        except Exception as e:
            logger.error("Re-authentication failed: %s", e)
            self.auth_strategy.access_token = old_token
            self.auth_strategy.expires_at = old_expires_at
            raise SalesforceAuthError(
                f"Authentication failed after 401 response: {e}"
            ) from e

        # Retry the request once with the new token.
        return await client.request(
            method, url, headers=self._merge_headers(headers), **kwargs
        )

    async def get(self, url: str, **kwargs) -> httpx.Response:
        """Make an authenticated GET request."""
        return await self.request("GET", url, **kwargs)

    async def post(self, url: str, **kwargs) -> httpx.Response:
        """Make an authenticated POST request."""
        return await self.request("POST", url, **kwargs)

    async def put(self, url: str, **kwargs) -> httpx.Response:
        """Make an authenticated PUT request."""
        return await self.request("PUT", url, **kwargs)

    async def delete(self, url: str, **kwargs) -> httpx.Response:
        """Make an authenticated DELETE request."""
        return await self.request("DELETE", url, **kwargs)
@@ -0,0 +1,38 @@
1
+ """
2
+ Exporter module for aio-salesforce.
3
+
4
+ This module contains utilities for exporting Salesforce data to various formats.
5
+ The entire module requires optional dependencies (pandas, pyarrow).
6
+ """
7
+
8
+ from .bulk_export import (
9
+ bulk_query,
10
+ get_bulk_fields,
11
+ resume_from_locator,
12
+ write_records_to_csv,
13
+ QueryResult,
14
+ batch_records,
15
+ batch_records_async,
16
+ )
17
+ from .parquet_writer import (
18
+ ParquetWriter,
19
+ create_schema_from_metadata,
20
+ write_query_to_parquet,
21
+ write_query_to_parquet_async,
22
+ salesforce_to_arrow_type,
23
+ )
24
+
25
# Public API of the exporter package: the names re-exported from the
# submodules imported above.
__all__ = [
    # from .bulk_export
    "bulk_query",
    "get_bulk_fields",
    "resume_from_locator",
    "write_records_to_csv",
    "QueryResult",
    "batch_records",
    "batch_records_async",
    # from .parquet_writer
    "ParquetWriter",
    "create_schema_from_metadata",
    "write_query_to_parquet",
    "write_query_to_parquet_async",
    "salesforce_to_arrow_type",
]