atlan-application-sdk 0.1.1rc33__py3-none-any.whl → 0.1.1rc35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. application_sdk/activities/__init__.py +3 -2
  2. application_sdk/activities/common/utils.py +21 -1
  3. application_sdk/activities/metadata_extraction/base.py +104 -0
  4. application_sdk/activities/metadata_extraction/sql.py +13 -12
  5. application_sdk/activities/query_extraction/sql.py +24 -20
  6. application_sdk/application/__init__.py +8 -0
  7. application_sdk/clients/atlan_auth.py +2 -2
  8. application_sdk/clients/base.py +293 -0
  9. application_sdk/clients/temporal.py +6 -10
  10. application_sdk/handlers/base.py +50 -0
  11. application_sdk/inputs/json.py +6 -4
  12. application_sdk/inputs/parquet.py +16 -13
  13. application_sdk/outputs/__init__.py +6 -3
  14. application_sdk/outputs/json.py +9 -6
  15. application_sdk/outputs/parquet.py +10 -36
  16. application_sdk/server/fastapi/__init__.py +4 -5
  17. application_sdk/server/fastapi/models.py +1 -1
  18. application_sdk/services/__init__.py +18 -0
  19. application_sdk/{outputs → services}/atlan_storage.py +64 -16
  20. application_sdk/{outputs → services}/eventstore.py +68 -6
  21. application_sdk/services/objectstore.py +407 -0
  22. application_sdk/services/secretstore.py +344 -0
  23. application_sdk/services/statestore.py +267 -0
  24. application_sdk/version.py +1 -1
  25. application_sdk/worker.py +1 -1
  26. {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/METADATA +1 -1
  27. {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/RECORD +30 -30
  28. application_sdk/common/credential_utils.py +0 -85
  29. application_sdk/inputs/objectstore.py +0 -238
  30. application_sdk/inputs/secretstore.py +0 -130
  31. application_sdk/inputs/statestore.py +0 -101
  32. application_sdk/outputs/objectstore.py +0 -125
  33. application_sdk/outputs/secretstore.py +0 -38
  34. application_sdk/outputs/statestore.py +0 -113
  35. {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/WHEEL +0 -0
  36. {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/licenses/LICENSE +0 -0
  37. {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/licenses/NOTICE +0 -0
@@ -29,7 +29,6 @@ from application_sdk.activities.common.utils import (
29
29
  from application_sdk.common.error_codes import OrchestratorError
30
30
  from application_sdk.constants import TEMPORARY_PATH
31
31
  from application_sdk.handlers import HandlerInterface
32
- from application_sdk.inputs.statestore import StateStoreInput, StateType
33
32
  from application_sdk.observability.logger_adaptor import get_logger
34
33
 
35
34
  logger = get_logger(__name__)
@@ -190,7 +189,9 @@ class ActivitiesInterface(ABC, Generic[ActivitiesStateType]):
190
189
 
191
190
  try:
192
191
  # This already handles the Dapr call internally
193
- workflow_args = StateStoreInput.get_state(workflow_id, StateType.WORKFLOWS)
192
+ from application_sdk.services.statestore import StateStore, StateType
193
+
194
+ workflow_args = await StateStore.get_state(workflow_id, StateType.WORKFLOWS)
194
195
  workflow_args["output_prefix"] = workflow_args.get(
195
196
  "output_prefix", TEMPORARY_PATH
196
197
  )
@@ -5,13 +5,18 @@ including workflow ID retrieval, automatic heartbeating, and periodic heartbeat
5
5
  """
6
6
 
7
7
  import asyncio
8
+ import os
8
9
  from datetime import timedelta
9
10
  from functools import wraps
10
11
  from typing import Any, Awaitable, Callable, Optional, TypeVar, cast
11
12
 
12
13
  from temporalio import activity
13
14
 
14
- from application_sdk.constants import APPLICATION_NAME, WORKFLOW_OUTPUT_PATH_TEMPLATE
15
+ from application_sdk.constants import (
16
+ APPLICATION_NAME,
17
+ TEMPORARY_PATH,
18
+ WORKFLOW_OUTPUT_PATH_TEMPLATE,
19
+ )
15
20
  from application_sdk.observability.logger_adaptor import get_logger
16
21
 
17
22
  logger = get_logger(__name__)
@@ -72,6 +77,21 @@ def build_output_path() -> str:
72
77
  )
73
78
 
74
79
 
80
+ def get_object_store_prefix(path: str) -> str:
81
+ """Get the object store prefix for the path.
82
+ Args:
83
+ path: The path to the output directory.
84
+
85
+ Returns:
86
+ The object store prefix for the path.
87
+
88
+ Example:
89
+ >>> get_object_store_prefix("./local/tmp/artifacts/apps/appName/workflows/wf-123/run-456")
90
+ "artifacts/apps/appName/workflows/wf-123/run-456"
91
+ """
92
+ return os.path.relpath(path, TEMPORARY_PATH)
93
+
94
+
75
95
  def auto_heartbeater(fn: F) -> F:
76
96
  """Decorator that automatically sends heartbeats during activity execution.
77
97
 
@@ -0,0 +1,104 @@
1
+ from typing import Any, Dict, Optional, Type
2
+
3
+ from temporalio import activity
4
+
5
+ from application_sdk.activities import ActivitiesInterface, ActivitiesState
6
+ from application_sdk.activities.common.utils import get_workflow_id
7
+ from application_sdk.clients.base import BaseClient
8
+ from application_sdk.constants import APP_TENANT_ID, APPLICATION_NAME
9
+ from application_sdk.handlers.base import BaseHandler
10
+ from application_sdk.observability.logger_adaptor import get_logger
11
+ from application_sdk.services.secretstore import SecretStore
12
+ from application_sdk.transformers import TransformerInterface
13
+
14
+ logger = get_logger(__name__)
15
+ activity.logger = logger
16
+
17
+
18
+ class BaseMetadataExtractionActivitiesState(ActivitiesState):
19
+ """State for base metadata extraction activities."""
20
+
21
+ client: Optional[BaseClient] = None
22
+ handler: Optional[BaseHandler] = None
23
+ transformer: Optional[TransformerInterface] = None
24
+
25
+
26
+ class BaseMetadataExtractionActivities(ActivitiesInterface):
27
+ """Base activities for non-SQL metadata extraction workflows."""
28
+
29
+ _state: Dict[str, BaseMetadataExtractionActivitiesState] = {}
30
+
31
+ client_class: Type[BaseClient] = BaseClient
32
+ handler_class: Type[BaseHandler] = BaseHandler
33
+ transformer_class: Optional[Type[TransformerInterface]] = None
34
+
35
+ def __init__(
36
+ self,
37
+ client_class: Optional[Type[BaseClient]] = None,
38
+ handler_class: Optional[Type[BaseHandler]] = None,
39
+ transformer_class: Optional[Type[TransformerInterface]] = None,
40
+ ):
41
+ """Initialize the base metadata extraction activities.
42
+
43
+ Args:
44
+ client_class: Client class to use. Defaults to BaseClient.
45
+ handler_class: Handler class to use. Defaults to BaseHandler.
46
+ transformer_class: Transformer class to use. Users must provide their own transformer implementation.
47
+ """
48
+ if client_class:
49
+ self.client_class = client_class
50
+ if handler_class:
51
+ self.handler_class = handler_class
52
+ if transformer_class:
53
+ self.transformer_class = transformer_class
54
+
55
+ super().__init__()
56
+
57
+ async def _set_state(self, workflow_args: Dict[str, Any]):
58
+ """Set up the state for the current workflow.
59
+
60
+ Args:
61
+ workflow_args: Arguments for the workflow.
62
+ """
63
+ workflow_id = get_workflow_id()
64
+ if not self._state.get(workflow_id):
65
+ self._state[workflow_id] = BaseMetadataExtractionActivitiesState()
66
+
67
+ await super()._set_state(workflow_args)
68
+
69
+ state = self._state[workflow_id]
70
+
71
+ # Initialize client
72
+ client = self.client_class()
73
+ # Extract credentials from state store if credential_guid is available
74
+ if "credential_guid" in workflow_args:
75
+ logger.info(
76
+ f"Retrieving credentials for credential_guid: {workflow_args['credential_guid']}"
77
+ )
78
+ try:
79
+ credentials = await SecretStore.get_credentials(
80
+ workflow_args["credential_guid"]
81
+ )
82
+ logger.info(
83
+ f"Successfully retrieved credentials with keys: {list(credentials.keys())}"
84
+ )
85
+ # Load the client with credentials
86
+ await client.load(credentials=credentials)
87
+ except Exception as e:
88
+ logger.error(f"Failed to retrieve credentials: {e}")
89
+ raise
90
+
91
+ state.client = client
92
+
93
+ # Initialize handler
94
+ handler = self.handler_class(client=client)
95
+ state.handler = handler
96
+
97
+ # Initialize transformer if provided
98
+ if self.transformer_class:
99
+ transformer_params = {
100
+ "connector_name": APPLICATION_NAME,
101
+ "connector_type": APPLICATION_NAME,
102
+ "tenant_id": APP_TENANT_ID,
103
+ }
104
+ state.transformer = self.transformer_class(**transformer_params)
@@ -5,25 +5,24 @@ from temporalio import activity
5
5
 
6
6
  from application_sdk.activities import ActivitiesInterface, ActivitiesState
7
7
  from application_sdk.activities.common.models import ActivityStatistics
8
- from application_sdk.activities.common.utils import auto_heartbeater, get_workflow_id
8
+ from application_sdk.activities.common.utils import (
9
+ auto_heartbeater,
10
+ get_object_store_prefix,
11
+ get_workflow_id,
12
+ )
9
13
  from application_sdk.clients.sql import BaseSQLClient
10
- from application_sdk.common.credential_utils import get_credentials
11
14
  from application_sdk.common.dataframe_utils import is_empty_dataframe
12
15
  from application_sdk.common.error_codes import ActivityError
13
16
  from application_sdk.common.utils import prepare_query, read_sql_files
14
- from application_sdk.constants import (
15
- APP_TENANT_ID,
16
- APPLICATION_NAME,
17
- SQL_QUERIES_PATH,
18
- TEMPORARY_PATH,
19
- )
17
+ from application_sdk.constants import APP_TENANT_ID, APPLICATION_NAME, SQL_QUERIES_PATH
20
18
  from application_sdk.handlers.sql import BaseSQLHandler
21
19
  from application_sdk.inputs.parquet import ParquetInput
22
20
  from application_sdk.inputs.sql_query import SQLQueryInput
23
21
  from application_sdk.observability.logger_adaptor import get_logger
24
- from application_sdk.outputs.atlan_storage import AtlanStorageOutput
25
22
  from application_sdk.outputs.json import JsonOutput
26
23
  from application_sdk.outputs.parquet import ParquetOutput
24
+ from application_sdk.services.atlan_storage import AtlanStorage
25
+ from application_sdk.services.secretstore import SecretStore
27
26
  from application_sdk.transformers import TransformerInterface
28
27
  from application_sdk.transformers.query import QueryBasedTransformer
29
28
 
@@ -144,7 +143,9 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
144
143
  self._state[workflow_id].handler = handler
145
144
 
146
145
  if "credential_guid" in workflow_args:
147
- credentials = await get_credentials(workflow_args["credential_guid"])
146
+ credentials = await SecretStore.get_credentials(
147
+ workflow_args["credential_guid"]
148
+ )
148
149
  await sql_client.load(credentials)
149
150
 
150
151
  self._state[workflow_id].sql_client = sql_client
@@ -536,11 +537,11 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
536
537
 
537
538
  # Upload data from object store to Atlan storage
538
539
  # Use workflow_id/workflow_run_id as the prefix to migrate specific data
539
- migration_prefix = os.path.relpath(workflow_args["output_path"], TEMPORARY_PATH)
540
+ migration_prefix = get_object_store_prefix(workflow_args["output_path"])
540
541
  logger.info(
541
542
  f"Starting migration from object store with prefix: {migration_prefix}"
542
543
  )
543
- upload_stats = await AtlanStorageOutput.migrate_from_objectstore_to_atlan(
544
+ upload_stats = await AtlanStorage.migrate_from_objectstore_to_atlan(
544
545
  prefix=migration_prefix
545
546
  )
546
547
 
@@ -7,17 +7,20 @@ from pydantic import BaseModel, Field
7
7
  from temporalio import activity
8
8
 
9
9
  from application_sdk.activities import ActivitiesInterface, ActivitiesState
10
- from application_sdk.activities.common.utils import auto_heartbeater, get_workflow_id
10
+ from application_sdk.activities.common.utils import (
11
+ auto_heartbeater,
12
+ get_object_store_prefix,
13
+ get_workflow_id,
14
+ )
11
15
  from application_sdk.clients.sql import BaseSQLClient
12
- from application_sdk.common.credential_utils import get_credentials
13
16
  from application_sdk.constants import UPSTREAM_OBJECT_STORE_NAME
14
17
  from application_sdk.handlers import HandlerInterface
15
18
  from application_sdk.handlers.sql import BaseSQLHandler
16
- from application_sdk.inputs.objectstore import ObjectStoreInput
17
19
  from application_sdk.inputs.sql_query import SQLQueryInput
18
20
  from application_sdk.observability.logger_adaptor import get_logger
19
- from application_sdk.outputs.objectstore import ObjectStoreOutput
20
21
  from application_sdk.outputs.parquet import ParquetOutput
22
+ from application_sdk.services.objectstore import ObjectStore
23
+ from application_sdk.services.secretstore import SecretStore
21
24
  from application_sdk.transformers import TransformerInterface
22
25
  from application_sdk.transformers.atlas import AtlasTransformer
23
26
 
@@ -129,7 +132,9 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
129
132
  workflow_id = get_workflow_id()
130
133
  sql_client = self.sql_client_class()
131
134
  if "credential_guid" in workflow_args:
132
- credentials = await get_credentials(workflow_args["credential_guid"])
135
+ credentials = await SecretStore.get_credentials(
136
+ workflow_args["credential_guid"]
137
+ )
133
138
  await sql_client.load(credentials)
134
139
 
135
140
  handler = self.handler_class(sql_client)
@@ -412,14 +417,14 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
412
417
  f.write(last_marker)
413
418
 
414
419
  logger.info(f"Last marker: {last_marker}")
415
- await ObjectStoreOutput.push_file_to_object_store(
416
- workflow_args["output_prefix"],
417
- marker_file_path,
418
- object_store_name=UPSTREAM_OBJECT_STORE_NAME,
420
+ await ObjectStore.upload_file(
421
+ source=marker_file_path,
422
+ destination=get_object_store_prefix(marker_file_path),
423
+ store_name=UPSTREAM_OBJECT_STORE_NAME,
419
424
  )
420
425
  logger.info(f"Marker file written to {marker_file_path}")
421
426
 
422
- def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
427
+ async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
423
428
  """Read the marker from the output path.
424
429
 
425
430
  This method reads the current marker value from a marker file to determine the
@@ -442,15 +447,12 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
442
447
  marker_file_path = os.path.join(output_path, "markerfile")
443
448
  logger.info(f"Downloading marker file from {marker_file_path}")
444
449
 
445
- os.makedirs(workflow_args["output_prefix"], exist_ok=True)
446
-
447
- ObjectStoreInput.download_file_from_object_store(
448
- workflow_args["output_prefix"],
449
- marker_file_path,
450
- object_store_name=UPSTREAM_OBJECT_STORE_NAME,
450
+ await ObjectStore.download_file(
451
+ source=get_object_store_prefix(marker_file_path),
452
+ destination=marker_file_path,
453
+ store_name=UPSTREAM_OBJECT_STORE_NAME,
451
454
  )
452
455
 
453
- logger.info(f"Output prefix: {workflow_args['output_prefix']}")
454
456
  logger.info(f"Marker file downloaded to {marker_file_path}")
455
457
  if not os.path.exists(marker_file_path):
456
458
  logger.warning(f"Marker file does not exist at {marker_file_path}")
@@ -487,7 +489,7 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
487
489
 
488
490
  miner_args = MinerArgs(**workflow_args.get("miner_args", {}))
489
491
 
490
- current_marker = self.read_marker(workflow_args)
492
+ current_marker = await self.read_marker(workflow_args)
491
493
  if current_marker:
492
494
  miner_args.current_marker = current_marker
493
495
 
@@ -522,8 +524,10 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
522
524
  with open(metadata_file_path, "w") as f:
523
525
  f.write(json.dumps(parallel_markers))
524
526
 
525
- await ObjectStoreOutput.push_file_to_object_store(
526
- workflow_args["output_prefix"], metadata_file_path
527
+ await ObjectStore.upload_file(
528
+ source=metadata_file_path,
529
+ destination=get_object_store_prefix(metadata_file_path),
530
+ store_name=UPSTREAM_OBJECT_STORE_NAME,
527
531
  )
528
532
 
529
533
  try:
@@ -2,8 +2,10 @@ from concurrent.futures import ThreadPoolExecutor
2
2
  from typing import Any, Dict, List, Optional, Tuple, Type
3
3
 
4
4
  from application_sdk.activities import ActivitiesInterface
5
+ from application_sdk.clients.base import BaseClient
5
6
  from application_sdk.clients.utils import get_workflow_client
6
7
  from application_sdk.events.models import EventRegistration
8
+ from application_sdk.handlers.base import BaseHandler
7
9
  from application_sdk.observability.logger_adaptor import get_logger
8
10
  from application_sdk.server import ServerInterface
9
11
  from application_sdk.server.fastapi import APIServer, HttpWorkflowTrigger
@@ -28,6 +30,8 @@ class BaseApplication:
28
30
  name: str,
29
31
  server: Optional[ServerInterface] = None,
30
32
  application_manifest: Optional[dict] = None,
33
+ client_class: Optional[Type[BaseClient]] = None,
34
+ handler_class: Optional[Type[BaseHandler]] = None,
31
35
  ):
32
36
  """
33
37
  Initialize the application.
@@ -48,6 +52,9 @@ class BaseApplication:
48
52
  self.application_manifest: Dict[str, Any] = application_manifest
49
53
  self.bootstrap_event_registration()
50
54
 
55
+ self.client_class = client_class or BaseClient
56
+ self.handler_class = handler_class or BaseHandler
57
+
51
58
  def bootstrap_event_registration(self):
52
59
  self.event_subscriptions = {}
53
60
  if self.application_manifest is None:
@@ -168,6 +175,7 @@ class BaseApplication:
168
175
  self.server = APIServer(
169
176
  workflow_client=self.workflow_client,
170
177
  ui_enabled=ui_enabled,
178
+ handler=self.handler_class(client=self.client_class()),
171
179
  )
172
180
 
173
181
  if self.event_subscriptions:
@@ -13,8 +13,8 @@ from application_sdk.constants import (
13
13
  WORKFLOW_AUTH_ENABLED,
14
14
  WORKFLOW_AUTH_URL_KEY,
15
15
  )
16
- from application_sdk.inputs.secretstore import SecretStoreInput
17
16
  from application_sdk.observability.logger_adaptor import get_logger
17
+ from application_sdk.services.secretstore import SecretStore
18
18
 
19
19
  logger = get_logger(__name__)
20
20
 
@@ -39,7 +39,7 @@ class AtlanAuthClient:
39
39
  (environment variables, AWS Secrets Manager, Azure Key Vault, etc.)
40
40
  """
41
41
  self.application_name = APPLICATION_NAME
42
- self.auth_config: Dict[str, Any] = SecretStoreInput.get_deployment_secret()
42
+ self.auth_config: Dict[str, Any] = SecretStore.get_deployment_secret()
43
43
  self.auth_enabled: bool = WORKFLOW_AUTH_ENABLED
44
44
  self.auth_url: Optional[str] = None
45
45
 
@@ -0,0 +1,293 @@
1
+ from typing import Any, Dict, Optional
2
+
3
+ import httpx
4
+ from httpx import Headers
5
+ from httpx._types import (
6
+ AuthTypes,
7
+ HeaderTypes,
8
+ QueryParamTypes,
9
+ RequestData,
10
+ RequestFiles,
11
+ )
12
+
13
+ from application_sdk.clients import ClientInterface
14
+ from application_sdk.observability.logger_adaptor import get_logger
15
+
16
+ logger = get_logger(__name__)
17
+
18
+
19
+ class BaseClient(ClientInterface):
20
+ """
21
+ Base client for non-SQL based applications.
22
+
23
+ This class provides a base implementation for clients that need to connect
24
+ to non-SQL data sources. It implements the ClientInterface and provides
25
+ basic functionality that can be extended by subclasses.
26
+
27
+ Attributes:
28
+ credentials (Dict[str, Any]): Client credentials for authentication.
29
+ http_headers (HeaderTypes): HTTP headers for all http requests made by this client. Supports dict, Headers object, or list of tuples.
30
+ http_retry_transporter (httpx.AsyncBaseTransport): HTTP transport for requests. Uses httpx default transport by default.
31
+ Can be overridden in load() method for custom retry behavior.
32
+
33
+ Extending the Client:
34
+ To customize retry behavior, subclasses can override the http_retry_transporter
35
+ in the load() method, similar to how http_headers is set:
36
+
37
+ Example:
38
+ >>> class MyClient(BaseClient):
39
+ ... async def load(self, **kwargs):
40
+ ... # Set up HTTP headers in load method for better modularity
41
+ ... credentials = kwargs.get("credentials", {})
42
+ ... # Can use dict, Headers object, or list of tuples
43
+ ... self.http_headers = {
44
+ ... "Authorization": f"Bearer {credentials.get('token')}",
45
+ ... "User-Agent": "MyApp/1.0"
46
+ ... }
47
+ ... # Optionally override retry transport with custom configuration
48
+ ... # For advanced retry logic with status code handling, use httpx-retries:
49
+ ... # from httpx_retries import Retry, RetryTransport
50
+ ... # retry = Retry(total=5, backoff_factor=20)
51
+ ... # self.http_retry_transporter = RetryTransport(retry=retry) #replace transport with custom transport if needed
52
+
53
+ Advanced Retry Configuration:
54
+ For applications requiring advanced retry logic (e.g., status code-based retries,
55
+ rate limiting, custom backoff strategies), consider using httpx-retries library:
56
+
57
+ >>> class MyClient(BaseClient):
58
+ ... async def load(self, **kwargs):
59
+ ... # Set up headers
60
+ ... self.http_headers = {"Authorization": f"Bearer {kwargs.get('token')}"}
61
+ ...
62
+ ... # Install httpx-retries: pip install httpx-retries
63
+ ... from httpx_retries import Retry, RetryTransport
64
+ ...
65
+ ... # Configure retry for status codes and network errors
66
+ ... retry = Retry(
67
+ ... total=5,
68
+ ... backoff_factor=10,
69
+ ... status_forcelist=[429, 500, 502, 503, 504]
70
+ ... )
71
+ ... self.http_retry_transporter = RetryTransport(retry=retry)
72
+
73
+ Header Management:
74
+ The client supports a two-level header system using httpx Headers for merging headers:
75
+ - Client-level headers: Set in the load() method and used for all requests
76
+ - Method-level headers: Passed to individual methods and override/add to client headers
77
+
78
+ Example:
79
+ >>> client = MyClient()
80
+ >>> await client.load(credentials={"token": "initial_token"})
81
+ >>> # This request will use: {"Authorization": "Bearer initial_token", "User-Agent": "MyApp/1.0", "Content-Type": "application/json"}
82
+ >>> response = await client.execute_http_post_request(
83
+ ... url="https://api.example.com/data",
84
+ ... headers={"Content-Type": "application/json"}
85
+ ... )
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ credentials: Dict[str, Any] = {},
91
+ http_headers: HeaderTypes = {},
92
+ ):
93
+ """
94
+ Initialize the base client.
95
+
96
+ Args:
97
+ credentials (Dict[str, Any], optional): Client credentials for authentication. Defaults to {}.
98
+ http_headers (HeaderTypes, optional): HTTP headers for all requests. Defaults to {}.
99
+ """
100
+ self.credentials = credentials
101
+ self.http_headers = http_headers
102
+
103
+ # Use httpx default transport (no retries on status codes)
104
+ self.http_retry_transport: httpx.AsyncBaseTransport = httpx.AsyncHTTPTransport()
105
+
106
+ async def load(self, **kwargs: Any) -> None:
107
+ """
108
+ Initialize the client with credentials and necessary attributes for the client to work.
109
+
110
+ This method should be implemented by subclasses to:
111
+ - Set up authentication headers in self.http_headers in case of http requests
112
+ - Initialize any required client state
113
+ - Handle credential processing
114
+ - Optionally override self.http_retry_transport for custom retry behavior
115
+
116
+ For advanced retry logic (status code-based retries, rate limiting, custom backoff),
117
+ consider using httpx-retries library and overriding http_retry_transport:
118
+
119
+ Example:
120
+ >>> async def load(self, **kwargs):
121
+ ... # Set up headers
122
+ ... self.http_headers = {"Authorization": f"Bearer {kwargs.get('token')}"}
123
+ ...
124
+ ... # For advanced retry logic, install httpx-retries: pip install httpx-retries
125
+ ... from httpx_retries import Retry, RetryTransport
126
+ ... retry = Retry(total=5, backoff_factor=10, status_forcelist=[429, 500, 502, 503, 504])
127
+ ... self.http_retry_transport = RetryTransport(retry=retry)
128
+
129
+ Args:
130
+ **kwargs: Additional keyword arguments, typically including credentials.
131
+ May also include retry configuration parameters that can be used to
132
+ create a custom http_retry_transport.
133
+
134
+ Raises:
135
+ NotImplementedError: If the subclass does not implement this method.
136
+ """
137
+ raise NotImplementedError("load method is not implemented")
138
+
139
+ async def execute_http_get_request(
140
+ self,
141
+ url: str,
142
+ headers: Optional[HeaderTypes] = None,
143
+ params: Optional[QueryParamTypes] = None,
144
+ auth: Optional[AuthTypes] = None,
145
+ timeout: int = 10,
146
+ ) -> Optional[httpx.Response]:
147
+ """
148
+ Perform an HTTP GET request using the configured transport.
149
+
150
+ This method uses httpx default transport which only retries on network-level errors
151
+ (connection failures, timeouts). For status code-based retries (429, 500, etc.),
152
+ consider overriding http_retry_transport in the load() method using httpx-retries library.
153
+
154
+ Args:
155
+ url (str): The URL to make the GET request to
156
+ headers (Optional[HeaderTypes]): HTTP headers to include in the request. Supports dict, Headers object, or list of tuples. These headers will override/add to any client-level headers set in the load() method.
157
+ params (Optional[QueryParamTypes]): Query parameters to include in the request. Supports dict, list of tuples, or string.
158
+ auth (Optional[AuthTypes]): Authentication to use for the request. Supports BasicAuth, DigestAuth, custom auth classes, or tuples for basic auth.
159
+ timeout (int): Request timeout in seconds. Defaults to 10.
160
+
161
+ Returns:
162
+ Optional[httpx.Response]: The HTTP response if successful, None if failed
163
+
164
+ Example:
165
+ >>> # Using Basic Authentication
166
+ >>> from httpx import BasicAuth
167
+ >>> response = await client.execute_http_get_request(
168
+ ... url="https://api.example.com/data",
169
+ ... auth=BasicAuth("username", "password"),
170
+ ... params={"limit": 100}
171
+ ... )
172
+ >>>
173
+ >>> # Using tuple for basic auth (username, password)
174
+ >>> response = await client.execute_http_get_request(
175
+ ... url="https://api.example.com/data",
176
+ ... auth=("username", "password"),
177
+ ... params={"limit": 100}
178
+ ... )
179
+ >>>
180
+ >>> # Using custom headers for Bearer token
181
+ >>> response = await client.execute_http_get_request(
182
+ ... url="https://api.example.com/data",
183
+ ... headers={"Authorization": "Bearer token"},
184
+ ... params={"limit": 100}
185
+ ... )
186
+ """
187
+ async with httpx.AsyncClient(
188
+ timeout=timeout, transport=self.http_retry_transport
189
+ ) as client:
190
+ merged_headers = Headers(self.http_headers)
191
+ if headers:
192
+ merged_headers.update(headers)
193
+
194
+ try:
195
+ response = await client.get(
196
+ url,
197
+ headers=merged_headers,
198
+ params=params,
199
+ auth=auth if auth is not None else httpx.USE_CLIENT_DEFAULT,
200
+ )
201
+ return response
202
+ except httpx.HTTPStatusError as e:
203
+ logger.error(f"HTTP error for {url}: {e.response.status_code}")
204
+ return None
205
+ except Exception as e:
206
+ logger.error(f"Request failed for {url}: {e}")
207
+ return None
208
+
209
+ async def execute_http_post_request(
210
+ self,
211
+ url: str,
212
+ data: Optional[RequestData] = None,
213
+ json_data: Optional[Any] = None,
214
+ content: Optional[bytes] = None,
215
+ files: Optional[RequestFiles] = None,
216
+ headers: Optional[HeaderTypes] = None,
217
+ params: Optional[QueryParamTypes] = None,
218
+ cookies: Optional[Dict[str, str]] = None,
219
+ auth: Optional[AuthTypes] = None,
220
+ follow_redirects: bool = True,
221
+ verify: bool = True,
222
+ timeout: int = 30,
223
+ ) -> Optional[httpx.Response]:
224
+ """
225
+ Perform an HTTP POST request using the configured transport.
226
+
227
+ This method uses httpx default transport which only retries on network-level errors
228
+ (connection failures, timeouts). For status code-based retries (429, 500, etc.),
229
+ consider overriding http_retry_transport in the load() method using httpx-retries library.
230
+
231
+ Args:
232
+ url (str): The URL to make the POST request to
233
+ data (Optional[RequestData]): Form data to send in the request body. Supports dict, list of tuples, or other httpx-compatible formats.
234
+ json_data (Optional[Any]): JSON data to send in the request body. Any JSON-serializable object.
235
+ content (Optional[bytes]): Raw binary content to send in the request body
236
+ files (Optional[RequestFiles]): Files to upload in the request body. Supports various file formats and tuples.
237
+ headers (Optional[HeaderTypes]): HTTP headers to include in the request. Supports dict, Headers object, or list of tuples. These headers will override/add to any client-level headers set in the load() method.
238
+ params (Optional[QueryParamTypes]): Query parameters to include in the request. Supports dict, list of tuples, or string.
239
+ cookies (Optional[Dict[str, str]]): Cookies to include in the request
240
+ auth (Optional[AuthTypes]): Authentication to use for the request. Supports BasicAuth, DigestAuth, custom auth classes, or tuples for basic auth.
241
+ follow_redirects (bool): Whether to follow HTTP redirects. Defaults to True.
242
+ verify (bool): Whether to verify SSL certificates. Defaults to True.
243
+ timeout (int): Request timeout in seconds. Defaults to 30.
244
+
245
+ Returns:
246
+ Optional[httpx.Response]: The HTTP response if successful, None if failed
247
+
248
+ Example:
249
+ >>> # Basic JSON POST request with authentication
250
+ >>> from httpx import BasicAuth
251
+ >>> response = await client.execute_http_post_request(
252
+ ... url="https://api.example.com/data",
253
+ ... json_data={"name": "test", "value": 123},
254
+ ... headers={"Content-Type": "application/json"},
255
+ ... auth=BasicAuth("username", "password")
256
+ ... )
257
+ >>>
258
+ >>> # File upload with basic auth tuple
259
+ >>> with open("file.txt", "rb") as f:
260
+ ... response = await client.execute_http_post_request(
261
+ ... url="https://api.example.com/upload",
262
+ ... data={"description": "My file"},
263
+ ... files={"file": ("file.txt", f.read(), "text/plain")},
264
+ ... auth=("username", "password")
265
+ ... )
266
+ """
267
+ async with httpx.AsyncClient(
268
+ timeout=timeout, transport=self.http_retry_transport, verify=verify
269
+ ) as client:
270
+ merged_headers = Headers(self.http_headers)
271
+ if headers:
272
+ merged_headers.update(headers)
273
+
274
+ try:
275
+ response = await client.post(
276
+ url,
277
+ data=data,
278
+ json=json_data,
279
+ content=content,
280
+ files=files,
281
+ headers=merged_headers,
282
+ params=params,
283
+ cookies=cookies,
284
+ auth=auth if auth is not None else httpx.USE_CLIENT_DEFAULT,
285
+ follow_redirects=follow_redirects,
286
+ )
287
+ return response
288
+ except httpx.HTTPStatusError as e:
289
+ logger.error(f"HTTP error for {url}: {e.response.status_code}")
290
+ return None
291
+ except Exception as e:
292
+ logger.error(f"Request failed for {url}: {e}")
293
+ return None