atlan-application-sdk 0.1.1rc33__py3-none-any.whl → 0.1.1rc35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- application_sdk/activities/__init__.py +3 -2
- application_sdk/activities/common/utils.py +21 -1
- application_sdk/activities/metadata_extraction/base.py +104 -0
- application_sdk/activities/metadata_extraction/sql.py +13 -12
- application_sdk/activities/query_extraction/sql.py +24 -20
- application_sdk/application/__init__.py +8 -0
- application_sdk/clients/atlan_auth.py +2 -2
- application_sdk/clients/base.py +293 -0
- application_sdk/clients/temporal.py +6 -10
- application_sdk/handlers/base.py +50 -0
- application_sdk/inputs/json.py +6 -4
- application_sdk/inputs/parquet.py +16 -13
- application_sdk/outputs/__init__.py +6 -3
- application_sdk/outputs/json.py +9 -6
- application_sdk/outputs/parquet.py +10 -36
- application_sdk/server/fastapi/__init__.py +4 -5
- application_sdk/server/fastapi/models.py +1 -1
- application_sdk/services/__init__.py +18 -0
- application_sdk/{outputs → services}/atlan_storage.py +64 -16
- application_sdk/{outputs → services}/eventstore.py +68 -6
- application_sdk/services/objectstore.py +407 -0
- application_sdk/services/secretstore.py +344 -0
- application_sdk/services/statestore.py +267 -0
- application_sdk/version.py +1 -1
- application_sdk/worker.py +1 -1
- {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/METADATA +1 -1
- {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/RECORD +30 -30
- application_sdk/common/credential_utils.py +0 -85
- application_sdk/inputs/objectstore.py +0 -238
- application_sdk/inputs/secretstore.py +0 -130
- application_sdk/inputs/statestore.py +0 -101
- application_sdk/outputs/objectstore.py +0 -125
- application_sdk/outputs/secretstore.py +0 -38
- application_sdk/outputs/statestore.py +0 -113
- {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/WHEEL +0 -0
- {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/licenses/LICENSE +0 -0
- {atlan_application_sdk-0.1.1rc33.dist-info → atlan_application_sdk-0.1.1rc35.dist-info}/licenses/NOTICE +0 -0
|
@@ -29,7 +29,6 @@ from application_sdk.activities.common.utils import (
|
|
|
29
29
|
from application_sdk.common.error_codes import OrchestratorError
|
|
30
30
|
from application_sdk.constants import TEMPORARY_PATH
|
|
31
31
|
from application_sdk.handlers import HandlerInterface
|
|
32
|
-
from application_sdk.inputs.statestore import StateStoreInput, StateType
|
|
33
32
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
34
33
|
|
|
35
34
|
logger = get_logger(__name__)
|
|
@@ -190,7 +189,9 @@ class ActivitiesInterface(ABC, Generic[ActivitiesStateType]):
|
|
|
190
189
|
|
|
191
190
|
try:
|
|
192
191
|
# This already handles the Dapr call internally
|
|
193
|
-
|
|
192
|
+
from application_sdk.services.statestore import StateStore, StateType
|
|
193
|
+
|
|
194
|
+
workflow_args = await StateStore.get_state(workflow_id, StateType.WORKFLOWS)
|
|
194
195
|
workflow_args["output_prefix"] = workflow_args.get(
|
|
195
196
|
"output_prefix", TEMPORARY_PATH
|
|
196
197
|
)
|
|
@@ -5,13 +5,18 @@ including workflow ID retrieval, automatic heartbeating, and periodic heartbeat
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
import asyncio
|
|
8
|
+
import os
|
|
8
9
|
from datetime import timedelta
|
|
9
10
|
from functools import wraps
|
|
10
11
|
from typing import Any, Awaitable, Callable, Optional, TypeVar, cast
|
|
11
12
|
|
|
12
13
|
from temporalio import activity
|
|
13
14
|
|
|
14
|
-
from application_sdk.constants import
|
|
15
|
+
from application_sdk.constants import (
|
|
16
|
+
APPLICATION_NAME,
|
|
17
|
+
TEMPORARY_PATH,
|
|
18
|
+
WORKFLOW_OUTPUT_PATH_TEMPLATE,
|
|
19
|
+
)
|
|
15
20
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
16
21
|
|
|
17
22
|
logger = get_logger(__name__)
|
|
@@ -72,6 +77,21 @@ def build_output_path() -> str:
|
|
|
72
77
|
)
|
|
73
78
|
|
|
74
79
|
|
|
80
|
+
def get_object_store_prefix(path: str) -> str:
|
|
81
|
+
"""Get the object store prefix for the path.
|
|
82
|
+
Args:
|
|
83
|
+
path: The path to the output directory.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
The object store prefix for the path.
|
|
87
|
+
|
|
88
|
+
Example:
|
|
89
|
+
>>> get_object_store_prefix("./local/tmp/artifacts/apps/appName/workflows/wf-123/run-456")
|
|
90
|
+
"artifacts/apps/appName/workflows/wf-123/run-456"
|
|
91
|
+
"""
|
|
92
|
+
return os.path.relpath(path, TEMPORARY_PATH)
|
|
93
|
+
|
|
94
|
+
|
|
75
95
|
def auto_heartbeater(fn: F) -> F:
|
|
76
96
|
"""Decorator that automatically sends heartbeats during activity execution.
|
|
77
97
|
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional, Type
|
|
2
|
+
|
|
3
|
+
from temporalio import activity
|
|
4
|
+
|
|
5
|
+
from application_sdk.activities import ActivitiesInterface, ActivitiesState
|
|
6
|
+
from application_sdk.activities.common.utils import get_workflow_id
|
|
7
|
+
from application_sdk.clients.base import BaseClient
|
|
8
|
+
from application_sdk.constants import APP_TENANT_ID, APPLICATION_NAME
|
|
9
|
+
from application_sdk.handlers.base import BaseHandler
|
|
10
|
+
from application_sdk.observability.logger_adaptor import get_logger
|
|
11
|
+
from application_sdk.services.secretstore import SecretStore
|
|
12
|
+
from application_sdk.transformers import TransformerInterface
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
activity.logger = logger
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseMetadataExtractionActivitiesState(ActivitiesState):
|
|
19
|
+
"""State for base metadata extraction activities."""
|
|
20
|
+
|
|
21
|
+
client: Optional[BaseClient] = None
|
|
22
|
+
handler: Optional[BaseHandler] = None
|
|
23
|
+
transformer: Optional[TransformerInterface] = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class BaseMetadataExtractionActivities(ActivitiesInterface):
|
|
27
|
+
"""Base activities for non-SQL metadata extraction workflows."""
|
|
28
|
+
|
|
29
|
+
_state: Dict[str, BaseMetadataExtractionActivitiesState] = {}
|
|
30
|
+
|
|
31
|
+
client_class: Type[BaseClient] = BaseClient
|
|
32
|
+
handler_class: Type[BaseHandler] = BaseHandler
|
|
33
|
+
transformer_class: Optional[Type[TransformerInterface]] = None
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
client_class: Optional[Type[BaseClient]] = None,
|
|
38
|
+
handler_class: Optional[Type[BaseHandler]] = None,
|
|
39
|
+
transformer_class: Optional[Type[TransformerInterface]] = None,
|
|
40
|
+
):
|
|
41
|
+
"""Initialize the base metadata extraction activities.
|
|
42
|
+
|
|
43
|
+
Args:
|
|
44
|
+
client_class: Client class to use. Defaults to BaseClient.
|
|
45
|
+
handler_class: Handler class to use. Defaults to BaseHandler.
|
|
46
|
+
transformer_class: Transformer class to use. Users must provide their own transformer implementation.
|
|
47
|
+
"""
|
|
48
|
+
if client_class:
|
|
49
|
+
self.client_class = client_class
|
|
50
|
+
if handler_class:
|
|
51
|
+
self.handler_class = handler_class
|
|
52
|
+
if transformer_class:
|
|
53
|
+
self.transformer_class = transformer_class
|
|
54
|
+
|
|
55
|
+
super().__init__()
|
|
56
|
+
|
|
57
|
+
async def _set_state(self, workflow_args: Dict[str, Any]):
|
|
58
|
+
"""Set up the state for the current workflow.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
workflow_args: Arguments for the workflow.
|
|
62
|
+
"""
|
|
63
|
+
workflow_id = get_workflow_id()
|
|
64
|
+
if not self._state.get(workflow_id):
|
|
65
|
+
self._state[workflow_id] = BaseMetadataExtractionActivitiesState()
|
|
66
|
+
|
|
67
|
+
await super()._set_state(workflow_args)
|
|
68
|
+
|
|
69
|
+
state = self._state[workflow_id]
|
|
70
|
+
|
|
71
|
+
# Initialize client
|
|
72
|
+
client = self.client_class()
|
|
73
|
+
# Extract credentials from the secret store if credential_guid is available
|
|
74
|
+
if "credential_guid" in workflow_args:
|
|
75
|
+
logger.info(
|
|
76
|
+
f"Retrieving credentials for credential_guid: {workflow_args['credential_guid']}"
|
|
77
|
+
)
|
|
78
|
+
try:
|
|
79
|
+
credentials = await SecretStore.get_credentials(
|
|
80
|
+
workflow_args["credential_guid"]
|
|
81
|
+
)
|
|
82
|
+
logger.info(
|
|
83
|
+
f"Successfully retrieved credentials with keys: {list(credentials.keys())}"
|
|
84
|
+
)
|
|
85
|
+
# Load the client with credentials
|
|
86
|
+
await client.load(credentials=credentials)
|
|
87
|
+
except Exception as e:
|
|
88
|
+
logger.error(f"Failed to retrieve credentials: {e}")
|
|
89
|
+
raise
|
|
90
|
+
|
|
91
|
+
state.client = client
|
|
92
|
+
|
|
93
|
+
# Initialize handler
|
|
94
|
+
handler = self.handler_class(client=client)
|
|
95
|
+
state.handler = handler
|
|
96
|
+
|
|
97
|
+
# Initialize transformer if provided
|
|
98
|
+
if self.transformer_class:
|
|
99
|
+
transformer_params = {
|
|
100
|
+
"connector_name": APPLICATION_NAME,
|
|
101
|
+
"connector_type": APPLICATION_NAME,
|
|
102
|
+
"tenant_id": APP_TENANT_ID,
|
|
103
|
+
}
|
|
104
|
+
state.transformer = self.transformer_class(**transformer_params)
|
|
@@ -5,25 +5,24 @@ from temporalio import activity
|
|
|
5
5
|
|
|
6
6
|
from application_sdk.activities import ActivitiesInterface, ActivitiesState
|
|
7
7
|
from application_sdk.activities.common.models import ActivityStatistics
|
|
8
|
-
from application_sdk.activities.common.utils import
|
|
8
|
+
from application_sdk.activities.common.utils import (
|
|
9
|
+
auto_heartbeater,
|
|
10
|
+
get_object_store_prefix,
|
|
11
|
+
get_workflow_id,
|
|
12
|
+
)
|
|
9
13
|
from application_sdk.clients.sql import BaseSQLClient
|
|
10
|
-
from application_sdk.common.credential_utils import get_credentials
|
|
11
14
|
from application_sdk.common.dataframe_utils import is_empty_dataframe
|
|
12
15
|
from application_sdk.common.error_codes import ActivityError
|
|
13
16
|
from application_sdk.common.utils import prepare_query, read_sql_files
|
|
14
|
-
from application_sdk.constants import
|
|
15
|
-
APP_TENANT_ID,
|
|
16
|
-
APPLICATION_NAME,
|
|
17
|
-
SQL_QUERIES_PATH,
|
|
18
|
-
TEMPORARY_PATH,
|
|
19
|
-
)
|
|
17
|
+
from application_sdk.constants import APP_TENANT_ID, APPLICATION_NAME, SQL_QUERIES_PATH
|
|
20
18
|
from application_sdk.handlers.sql import BaseSQLHandler
|
|
21
19
|
from application_sdk.inputs.parquet import ParquetInput
|
|
22
20
|
from application_sdk.inputs.sql_query import SQLQueryInput
|
|
23
21
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
24
|
-
from application_sdk.outputs.atlan_storage import AtlanStorageOutput
|
|
25
22
|
from application_sdk.outputs.json import JsonOutput
|
|
26
23
|
from application_sdk.outputs.parquet import ParquetOutput
|
|
24
|
+
from application_sdk.services.atlan_storage import AtlanStorage
|
|
25
|
+
from application_sdk.services.secretstore import SecretStore
|
|
27
26
|
from application_sdk.transformers import TransformerInterface
|
|
28
27
|
from application_sdk.transformers.query import QueryBasedTransformer
|
|
29
28
|
|
|
@@ -144,7 +143,9 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
|
|
|
144
143
|
self._state[workflow_id].handler = handler
|
|
145
144
|
|
|
146
145
|
if "credential_guid" in workflow_args:
|
|
147
|
-
credentials = await get_credentials(
|
|
146
|
+
credentials = await SecretStore.get_credentials(
|
|
147
|
+
workflow_args["credential_guid"]
|
|
148
|
+
)
|
|
148
149
|
await sql_client.load(credentials)
|
|
149
150
|
|
|
150
151
|
self._state[workflow_id].sql_client = sql_client
|
|
@@ -536,11 +537,11 @@ class BaseSQLMetadataExtractionActivities(ActivitiesInterface):
|
|
|
536
537
|
|
|
537
538
|
# Upload data from object store to Atlan storage
|
|
538
539
|
# Use workflow_id/workflow_run_id as the prefix to migrate specific data
|
|
539
|
-
migration_prefix =
|
|
540
|
+
migration_prefix = get_object_store_prefix(workflow_args["output_path"])
|
|
540
541
|
logger.info(
|
|
541
542
|
f"Starting migration from object store with prefix: {migration_prefix}"
|
|
542
543
|
)
|
|
543
|
-
upload_stats = await
|
|
544
|
+
upload_stats = await AtlanStorage.migrate_from_objectstore_to_atlan(
|
|
544
545
|
prefix=migration_prefix
|
|
545
546
|
)
|
|
546
547
|
|
|
@@ -7,17 +7,20 @@ from pydantic import BaseModel, Field
|
|
|
7
7
|
from temporalio import activity
|
|
8
8
|
|
|
9
9
|
from application_sdk.activities import ActivitiesInterface, ActivitiesState
|
|
10
|
-
from application_sdk.activities.common.utils import
|
|
10
|
+
from application_sdk.activities.common.utils import (
|
|
11
|
+
auto_heartbeater,
|
|
12
|
+
get_object_store_prefix,
|
|
13
|
+
get_workflow_id,
|
|
14
|
+
)
|
|
11
15
|
from application_sdk.clients.sql import BaseSQLClient
|
|
12
|
-
from application_sdk.common.credential_utils import get_credentials
|
|
13
16
|
from application_sdk.constants import UPSTREAM_OBJECT_STORE_NAME
|
|
14
17
|
from application_sdk.handlers import HandlerInterface
|
|
15
18
|
from application_sdk.handlers.sql import BaseSQLHandler
|
|
16
|
-
from application_sdk.inputs.objectstore import ObjectStoreInput
|
|
17
19
|
from application_sdk.inputs.sql_query import SQLQueryInput
|
|
18
20
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
19
|
-
from application_sdk.outputs.objectstore import ObjectStoreOutput
|
|
20
21
|
from application_sdk.outputs.parquet import ParquetOutput
|
|
22
|
+
from application_sdk.services.objectstore import ObjectStore
|
|
23
|
+
from application_sdk.services.secretstore import SecretStore
|
|
21
24
|
from application_sdk.transformers import TransformerInterface
|
|
22
25
|
from application_sdk.transformers.atlas import AtlasTransformer
|
|
23
26
|
|
|
@@ -129,7 +132,9 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
|
|
|
129
132
|
workflow_id = get_workflow_id()
|
|
130
133
|
sql_client = self.sql_client_class()
|
|
131
134
|
if "credential_guid" in workflow_args:
|
|
132
|
-
credentials = await get_credentials(
|
|
135
|
+
credentials = await SecretStore.get_credentials(
|
|
136
|
+
workflow_args["credential_guid"]
|
|
137
|
+
)
|
|
133
138
|
await sql_client.load(credentials)
|
|
134
139
|
|
|
135
140
|
handler = self.handler_class(sql_client)
|
|
@@ -412,14 +417,14 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
|
|
|
412
417
|
f.write(last_marker)
|
|
413
418
|
|
|
414
419
|
logger.info(f"Last marker: {last_marker}")
|
|
415
|
-
await
|
|
416
|
-
|
|
417
|
-
marker_file_path,
|
|
418
|
-
|
|
420
|
+
await ObjectStore.upload_file(
|
|
421
|
+
source=marker_file_path,
|
|
422
|
+
destination=get_object_store_prefix(marker_file_path),
|
|
423
|
+
store_name=UPSTREAM_OBJECT_STORE_NAME,
|
|
419
424
|
)
|
|
420
425
|
logger.info(f"Marker file written to {marker_file_path}")
|
|
421
426
|
|
|
422
|
-
def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
|
|
427
|
+
async def read_marker(self, workflow_args: Dict[str, Any]) -> Optional[int]:
|
|
423
428
|
"""Read the marker from the output path.
|
|
424
429
|
|
|
425
430
|
This method reads the current marker value from a marker file to determine the
|
|
@@ -442,15 +447,12 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
|
|
|
442
447
|
marker_file_path = os.path.join(output_path, "markerfile")
|
|
443
448
|
logger.info(f"Downloading marker file from {marker_file_path}")
|
|
444
449
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
marker_file_path,
|
|
450
|
-
object_store_name=UPSTREAM_OBJECT_STORE_NAME,
|
|
450
|
+
await ObjectStore.download_file(
|
|
451
|
+
source=get_object_store_prefix(marker_file_path),
|
|
452
|
+
destination=marker_file_path,
|
|
453
|
+
store_name=UPSTREAM_OBJECT_STORE_NAME,
|
|
451
454
|
)
|
|
452
455
|
|
|
453
|
-
logger.info(f"Output prefix: {workflow_args['output_prefix']}")
|
|
454
456
|
logger.info(f"Marker file downloaded to {marker_file_path}")
|
|
455
457
|
if not os.path.exists(marker_file_path):
|
|
456
458
|
logger.warning(f"Marker file does not exist at {marker_file_path}")
|
|
@@ -487,7 +489,7 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
|
|
|
487
489
|
|
|
488
490
|
miner_args = MinerArgs(**workflow_args.get("miner_args", {}))
|
|
489
491
|
|
|
490
|
-
current_marker = self.read_marker(workflow_args)
|
|
492
|
+
current_marker = await self.read_marker(workflow_args)
|
|
491
493
|
if current_marker:
|
|
492
494
|
miner_args.current_marker = current_marker
|
|
493
495
|
|
|
@@ -522,8 +524,10 @@ class SQLQueryExtractionActivities(ActivitiesInterface):
|
|
|
522
524
|
with open(metadata_file_path, "w") as f:
|
|
523
525
|
f.write(json.dumps(parallel_markers))
|
|
524
526
|
|
|
525
|
-
await
|
|
526
|
-
|
|
527
|
+
await ObjectStore.upload_file(
|
|
528
|
+
source=metadata_file_path,
|
|
529
|
+
destination=get_object_store_prefix(metadata_file_path),
|
|
530
|
+
store_name=UPSTREAM_OBJECT_STORE_NAME,
|
|
527
531
|
)
|
|
528
532
|
|
|
529
533
|
try:
|
|
@@ -2,8 +2,10 @@ from concurrent.futures import ThreadPoolExecutor
|
|
|
2
2
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
3
3
|
|
|
4
4
|
from application_sdk.activities import ActivitiesInterface
|
|
5
|
+
from application_sdk.clients.base import BaseClient
|
|
5
6
|
from application_sdk.clients.utils import get_workflow_client
|
|
6
7
|
from application_sdk.events.models import EventRegistration
|
|
8
|
+
from application_sdk.handlers.base import BaseHandler
|
|
7
9
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
8
10
|
from application_sdk.server import ServerInterface
|
|
9
11
|
from application_sdk.server.fastapi import APIServer, HttpWorkflowTrigger
|
|
@@ -28,6 +30,8 @@ class BaseApplication:
|
|
|
28
30
|
name: str,
|
|
29
31
|
server: Optional[ServerInterface] = None,
|
|
30
32
|
application_manifest: Optional[dict] = None,
|
|
33
|
+
client_class: Optional[Type[BaseClient]] = None,
|
|
34
|
+
handler_class: Optional[Type[BaseHandler]] = None,
|
|
31
35
|
):
|
|
32
36
|
"""
|
|
33
37
|
Initialize the application.
|
|
@@ -48,6 +52,9 @@ class BaseApplication:
|
|
|
48
52
|
self.application_manifest: Dict[str, Any] = application_manifest
|
|
49
53
|
self.bootstrap_event_registration()
|
|
50
54
|
|
|
55
|
+
self.client_class = client_class or BaseClient
|
|
56
|
+
self.handler_class = handler_class or BaseHandler
|
|
57
|
+
|
|
51
58
|
def bootstrap_event_registration(self):
|
|
52
59
|
self.event_subscriptions = {}
|
|
53
60
|
if self.application_manifest is None:
|
|
@@ -168,6 +175,7 @@ class BaseApplication:
|
|
|
168
175
|
self.server = APIServer(
|
|
169
176
|
workflow_client=self.workflow_client,
|
|
170
177
|
ui_enabled=ui_enabled,
|
|
178
|
+
handler=self.handler_class(client=self.client_class()),
|
|
171
179
|
)
|
|
172
180
|
|
|
173
181
|
if self.event_subscriptions:
|
|
@@ -13,8 +13,8 @@ from application_sdk.constants import (
|
|
|
13
13
|
WORKFLOW_AUTH_ENABLED,
|
|
14
14
|
WORKFLOW_AUTH_URL_KEY,
|
|
15
15
|
)
|
|
16
|
-
from application_sdk.inputs.secretstore import SecretStoreInput
|
|
17
16
|
from application_sdk.observability.logger_adaptor import get_logger
|
|
17
|
+
from application_sdk.services.secretstore import SecretStore
|
|
18
18
|
|
|
19
19
|
logger = get_logger(__name__)
|
|
20
20
|
|
|
@@ -39,7 +39,7 @@ class AtlanAuthClient:
|
|
|
39
39
|
(environment variables, AWS Secrets Manager, Azure Key Vault, etc.)
|
|
40
40
|
"""
|
|
41
41
|
self.application_name = APPLICATION_NAME
|
|
42
|
-
self.auth_config: Dict[str, Any] =
|
|
42
|
+
self.auth_config: Dict[str, Any] = SecretStore.get_deployment_secret()
|
|
43
43
|
self.auth_enabled: bool = WORKFLOW_AUTH_ENABLED
|
|
44
44
|
self.auth_url: Optional[str] = None
|
|
45
45
|
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
|
+
|
|
3
|
+
import httpx
|
|
4
|
+
from httpx import Headers
|
|
5
|
+
from httpx._types import (
|
|
6
|
+
AuthTypes,
|
|
7
|
+
HeaderTypes,
|
|
8
|
+
QueryParamTypes,
|
|
9
|
+
RequestData,
|
|
10
|
+
RequestFiles,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from application_sdk.clients import ClientInterface
|
|
14
|
+
from application_sdk.observability.logger_adaptor import get_logger
|
|
15
|
+
|
|
16
|
+
logger = get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class BaseClient(ClientInterface):
|
|
20
|
+
"""
|
|
21
|
+
Base client for non-SQL based applications.
|
|
22
|
+
|
|
23
|
+
This class provides a base implementation for clients that need to connect
|
|
24
|
+
to non-SQL data sources. It implements the ClientInterface and provides
|
|
25
|
+
basic functionality that can be extended by subclasses.
|
|
26
|
+
|
|
27
|
+
Attributes:
|
|
28
|
+
credentials (Dict[str, Any]): Client credentials for authentication.
|
|
29
|
+
http_headers (HeaderTypes): HTTP headers for all http requests made by this client. Supports dict, Headers object, or list of tuples.
|
|
30
|
+
http_retry_transport (httpx.AsyncBaseTransport): HTTP transport for requests. Uses httpx default transport by default.
|
|
31
|
+
Can be overridden in load() method for custom retry behavior.
|
|
32
|
+
|
|
33
|
+
Extending the Client:
|
|
34
|
+
To customize retry behavior, subclasses can override the http_retry_transport
|
|
35
|
+
in the load() method, similar to how http_headers is set:
|
|
36
|
+
|
|
37
|
+
Example:
|
|
38
|
+
>>> class MyClient(BaseClient):
|
|
39
|
+
... async def load(self, **kwargs):
|
|
40
|
+
... # Set up HTTP headers in load method for better modularity
|
|
41
|
+
... credentials = kwargs.get("credentials", {})
|
|
42
|
+
... # Can use dict, Headers object, or list of tuples
|
|
43
|
+
... self.http_headers = {
|
|
44
|
+
... "Authorization": f"Bearer {credentials.get('token')}",
|
|
45
|
+
... "User-Agent": "MyApp/1.0"
|
|
46
|
+
... }
|
|
47
|
+
... # Optionally override retry transport with custom configuration
|
|
48
|
+
... # For advanced retry logic with status code handling, use httpx-retries:
|
|
49
|
+
... # from httpx_retries import Retry, RetryTransport
|
|
50
|
+
... # retry = Retry(total=5, backoff_factor=20)
|
|
51
|
+
... # self.http_retry_transport = RetryTransport(retry=retry)  # replace transport with custom transport if needed
|
|
52
|
+
|
|
53
|
+
Advanced Retry Configuration:
|
|
54
|
+
For applications requiring advanced retry logic (e.g., status code-based retries,
|
|
55
|
+
rate limiting, custom backoff strategies), consider using httpx-retries library:
|
|
56
|
+
|
|
57
|
+
>>> class MyClient(BaseClient):
|
|
58
|
+
... async def load(self, **kwargs):
|
|
59
|
+
... # Set up headers
|
|
60
|
+
... self.http_headers = {"Authorization": f"Bearer {kwargs.get('token')}"}
|
|
61
|
+
...
|
|
62
|
+
... # Install httpx-retries: pip install httpx-retries
|
|
63
|
+
... from httpx_retries import Retry, RetryTransport
|
|
64
|
+
...
|
|
65
|
+
... # Configure retry for status codes and network errors
|
|
66
|
+
... retry = Retry(
|
|
67
|
+
... total=5,
|
|
68
|
+
... backoff_factor=10,
|
|
69
|
+
... status_forcelist=[429, 500, 502, 503, 504]
|
|
70
|
+
... )
|
|
71
|
+
... self.http_retry_transport = RetryTransport(retry=retry)
|
|
72
|
+
|
|
73
|
+
Header Management:
|
|
74
|
+
The client supports a two-level header system using httpx Headers for merging headers:
|
|
75
|
+
- Client-level headers: Set in the load() method and used for all requests
|
|
76
|
+
- Method-level headers: Passed to individual methods and override/add to client headers
|
|
77
|
+
|
|
78
|
+
Example:
|
|
79
|
+
>>> client = MyClient()
|
|
80
|
+
>>> await client.load(credentials={"token": "initial_token"})
|
|
81
|
+
>>> # This request will use: {"Authorization": "Bearer initial_token", "User-Agent": "MyApp/1.0", "Content-Type": "application/json"}
|
|
82
|
+
>>> response = await client.execute_http_post_request(
|
|
83
|
+
... url="https://api.example.com/data",
|
|
84
|
+
... headers={"Content-Type": "application/json"}
|
|
85
|
+
... )
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
def __init__(
|
|
89
|
+
self,
|
|
90
|
+
credentials: Dict[str, Any] = {},
|
|
91
|
+
http_headers: HeaderTypes = {},
|
|
92
|
+
):
|
|
93
|
+
"""
|
|
94
|
+
Initialize the base client.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
credentials (Dict[str, Any], optional): Client credentials for authentication. Defaults to {}.
|
|
98
|
+
http_headers (HeaderTypes, optional): HTTP headers for all requests. Defaults to {}.
|
|
99
|
+
"""
|
|
100
|
+
self.credentials = credentials
|
|
101
|
+
self.http_headers = http_headers
|
|
102
|
+
|
|
103
|
+
# Use httpx default transport (no retries on status codes)
|
|
104
|
+
self.http_retry_transport: httpx.AsyncBaseTransport = httpx.AsyncHTTPTransport()
|
|
105
|
+
|
|
106
|
+
async def load(self, **kwargs: Any) -> None:
|
|
107
|
+
"""
|
|
108
|
+
Initialize the client with credentials and necessary attributes for the client to work.
|
|
109
|
+
|
|
110
|
+
This method should be implemented by subclasses to:
|
|
111
|
+
- Set up authentication headers in self.http_headers in case of HTTP requests
|
|
112
|
+
- Initialize any required client state
|
|
113
|
+
- Handle credential processing
|
|
114
|
+
- Optionally override self.http_retry_transport for custom retry behavior
|
|
115
|
+
|
|
116
|
+
For advanced retry logic (status code-based retries, rate limiting, custom backoff),
|
|
117
|
+
consider using httpx-retries library and overriding http_retry_transport:
|
|
118
|
+
|
|
119
|
+
Example:
|
|
120
|
+
>>> async def load(self, **kwargs):
|
|
121
|
+
... # Set up headers
|
|
122
|
+
... self.http_headers = {"Authorization": f"Bearer {kwargs.get('token')}"}
|
|
123
|
+
...
|
|
124
|
+
... # For advanced retry logic, install httpx-retries: pip install httpx-retries
|
|
125
|
+
... from httpx_retries import Retry, RetryTransport
|
|
126
|
+
... retry = Retry(total=5, backoff_factor=10, status_forcelist=[429, 500, 502, 503, 504])
|
|
127
|
+
... self.http_retry_transport = RetryTransport(retry=retry)
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
**kwargs: Additional keyword arguments, typically including credentials.
|
|
131
|
+
May also include retry configuration parameters that can be used to
|
|
132
|
+
create a custom http_retry_transport.
|
|
133
|
+
|
|
134
|
+
Raises:
|
|
135
|
+
NotImplementedError: If the subclass does not implement this method.
|
|
136
|
+
"""
|
|
137
|
+
raise NotImplementedError("load method is not implemented")
|
|
138
|
+
|
|
139
|
+
async def execute_http_get_request(
|
|
140
|
+
self,
|
|
141
|
+
url: str,
|
|
142
|
+
headers: Optional[HeaderTypes] = None,
|
|
143
|
+
params: Optional[QueryParamTypes] = None,
|
|
144
|
+
auth: Optional[AuthTypes] = None,
|
|
145
|
+
timeout: int = 10,
|
|
146
|
+
) -> Optional[httpx.Response]:
|
|
147
|
+
"""
|
|
148
|
+
Perform an HTTP GET request using the configured transport.
|
|
149
|
+
|
|
150
|
+
This method uses httpx default transport, which performs no retries unless configured and can then only retry network-level errors
|
|
151
|
+
(connection failures, timeouts). For status code-based retries (429, 500, etc.),
|
|
152
|
+
consider overriding http_retry_transport in the load() method using httpx-retries library.
|
|
153
|
+
|
|
154
|
+
Args:
|
|
155
|
+
url (str): The URL to make the GET request to
|
|
156
|
+
headers (Optional[HeaderTypes]): HTTP headers to include in the request. Supports dict, Headers object, or list of tuples. These headers will override/add to any client-level headers set in the load() method.
|
|
157
|
+
params (Optional[QueryParamTypes]): Query parameters to include in the request. Supports dict, list of tuples, or string.
|
|
158
|
+
auth (Optional[AuthTypes]): Authentication to use for the request. Supports BasicAuth, DigestAuth, custom auth classes, or tuples for basic auth.
|
|
159
|
+
timeout (int): Request timeout in seconds. Defaults to 10.
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
Optional[httpx.Response]: The HTTP response if successful, None if failed
|
|
163
|
+
|
|
164
|
+
Example:
|
|
165
|
+
>>> # Using Basic Authentication
|
|
166
|
+
>>> from httpx import BasicAuth
|
|
167
|
+
>>> response = await client.execute_http_get_request(
|
|
168
|
+
... url="https://api.example.com/data",
|
|
169
|
+
... auth=BasicAuth("username", "password"),
|
|
170
|
+
... params={"limit": 100}
|
|
171
|
+
... )
|
|
172
|
+
>>>
|
|
173
|
+
>>> # Using tuple for basic auth (username, password)
|
|
174
|
+
>>> response = await client.execute_http_get_request(
|
|
175
|
+
... url="https://api.example.com/data",
|
|
176
|
+
... auth=("username", "password"),
|
|
177
|
+
... params={"limit": 100}
|
|
178
|
+
... )
|
|
179
|
+
>>>
|
|
180
|
+
>>> # Using custom headers for Bearer token
|
|
181
|
+
>>> response = await client.execute_http_get_request(
|
|
182
|
+
... url="https://api.example.com/data",
|
|
183
|
+
... headers={"Authorization": "Bearer token"},
|
|
184
|
+
... params={"limit": 100}
|
|
185
|
+
... )
|
|
186
|
+
"""
|
|
187
|
+
async with httpx.AsyncClient(
|
|
188
|
+
timeout=timeout, transport=self.http_retry_transport
|
|
189
|
+
) as client:
|
|
190
|
+
merged_headers = Headers(self.http_headers)
|
|
191
|
+
if headers:
|
|
192
|
+
merged_headers.update(headers)
|
|
193
|
+
|
|
194
|
+
try:
|
|
195
|
+
response = await client.get(
|
|
196
|
+
url,
|
|
197
|
+
headers=merged_headers,
|
|
198
|
+
params=params,
|
|
199
|
+
auth=auth if auth is not None else httpx.USE_CLIENT_DEFAULT,
|
|
200
|
+
)
|
|
201
|
+
return response
|
|
202
|
+
except httpx.HTTPStatusError as e:
|
|
203
|
+
logger.error(f"HTTP error for {url}: {e.response.status_code}")
|
|
204
|
+
return None
|
|
205
|
+
except Exception as e:
|
|
206
|
+
logger.error(f"Request failed for {url}: {e}")
|
|
207
|
+
return None
|
|
208
|
+
|
|
209
|
+
async def execute_http_post_request(
|
|
210
|
+
self,
|
|
211
|
+
url: str,
|
|
212
|
+
data: Optional[RequestData] = None,
|
|
213
|
+
json_data: Optional[Any] = None,
|
|
214
|
+
content: Optional[bytes] = None,
|
|
215
|
+
files: Optional[RequestFiles] = None,
|
|
216
|
+
headers: Optional[HeaderTypes] = None,
|
|
217
|
+
params: Optional[QueryParamTypes] = None,
|
|
218
|
+
cookies: Optional[Dict[str, str]] = None,
|
|
219
|
+
auth: Optional[AuthTypes] = None,
|
|
220
|
+
follow_redirects: bool = True,
|
|
221
|
+
verify: bool = True,
|
|
222
|
+
timeout: int = 30,
|
|
223
|
+
) -> Optional[httpx.Response]:
|
|
224
|
+
"""
|
|
225
|
+
Perform an HTTP POST request using the configured transport.
|
|
226
|
+
|
|
227
|
+
This method uses httpx default transport, which performs no retries unless configured and can then only retry network-level errors
|
|
228
|
+
(connection failures, timeouts). For status code-based retries (429, 500, etc.),
|
|
229
|
+
consider overriding http_retry_transport in the load() method using httpx-retries library.
|
|
230
|
+
|
|
231
|
+
Args:
|
|
232
|
+
url (str): The URL to make the POST request to
|
|
233
|
+
data (Optional[RequestData]): Form data to send in the request body. Supports dict, list of tuples, or other httpx-compatible formats.
|
|
234
|
+
json_data (Optional[Any]): JSON data to send in the request body. Any JSON-serializable object.
|
|
235
|
+
content (Optional[bytes]): Raw binary content to send in the request body
|
|
236
|
+
files (Optional[RequestFiles]): Files to upload in the request body. Supports various file formats and tuples.
|
|
237
|
+
headers (Optional[HeaderTypes]): HTTP headers to include in the request. Supports dict, Headers object, or list of tuples. These headers will override/add to any client-level headers set in the load() method.
|
|
238
|
+
params (Optional[QueryParamTypes]): Query parameters to include in the request. Supports dict, list of tuples, or string.
|
|
239
|
+
cookies (Optional[Dict[str, str]]): Cookies to include in the request
|
|
240
|
+
auth (Optional[AuthTypes]): Authentication to use for the request. Supports BasicAuth, DigestAuth, custom auth classes, or tuples for basic auth.
|
|
241
|
+
follow_redirects (bool): Whether to follow HTTP redirects. Defaults to True.
|
|
242
|
+
verify (bool): Whether to verify SSL certificates. Defaults to True.
|
|
243
|
+
timeout (int): Request timeout in seconds. Defaults to 30.
|
|
244
|
+
|
|
245
|
+
Returns:
|
|
246
|
+
Optional[httpx.Response]: The HTTP response if successful, None if failed
|
|
247
|
+
|
|
248
|
+
Example:
|
|
249
|
+
>>> # Basic JSON POST request with authentication
|
|
250
|
+
>>> from httpx import BasicAuth
|
|
251
|
+
>>> response = await client.execute_http_post_request(
|
|
252
|
+
... url="https://api.example.com/data",
|
|
253
|
+
... json_data={"name": "test", "value": 123},
|
|
254
|
+
... headers={"Content-Type": "application/json"},
|
|
255
|
+
... auth=BasicAuth("username", "password")
|
|
256
|
+
... )
|
|
257
|
+
>>>
|
|
258
|
+
>>> # File upload with basic auth tuple
|
|
259
|
+
>>> with open("file.txt", "rb") as f:
|
|
260
|
+
... response = await client.execute_http_post_request(
|
|
261
|
+
... url="https://api.example.com/upload",
|
|
262
|
+
... data={"description": "My file"},
|
|
263
|
+
... files={"file": ("file.txt", f.read(), "text/plain")},
|
|
264
|
+
... auth=("username", "password")
|
|
265
|
+
... )
|
|
266
|
+
"""
|
|
267
|
+
async with httpx.AsyncClient(
|
|
268
|
+
timeout=timeout, transport=self.http_retry_transport, verify=verify
|
|
269
|
+
) as client:
|
|
270
|
+
merged_headers = Headers(self.http_headers)
|
|
271
|
+
if headers:
|
|
272
|
+
merged_headers.update(headers)
|
|
273
|
+
|
|
274
|
+
try:
|
|
275
|
+
response = await client.post(
|
|
276
|
+
url,
|
|
277
|
+
data=data,
|
|
278
|
+
json=json_data,
|
|
279
|
+
content=content,
|
|
280
|
+
files=files,
|
|
281
|
+
headers=merged_headers,
|
|
282
|
+
params=params,
|
|
283
|
+
cookies=cookies,
|
|
284
|
+
auth=auth if auth is not None else httpx.USE_CLIENT_DEFAULT,
|
|
285
|
+
follow_redirects=follow_redirects,
|
|
286
|
+
)
|
|
287
|
+
return response
|
|
288
|
+
except httpx.HTTPStatusError as e:
|
|
289
|
+
logger.error(f"HTTP error for {url}: {e.response.status_code}")
|
|
290
|
+
return None
|
|
291
|
+
except Exception as e:
|
|
292
|
+
logger.error(f"Request failed for {url}: {e}")
|
|
293
|
+
return None
|